Skip to content

Commit 8741bb9

Browse files
authored
Update p3b1.py
1 parent 516ee5f commit 8741bb9

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

Pilot3/P3B1/p3b1.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import candle_keras as candle
1616

17-
'''
1817
additional_definitions = [
1918
{'name':'train_features',
2019
'action':'store',
@@ -58,17 +57,16 @@
5857
'action':'store',
5958
'type':int}
6059
]
61-
'''
6260

6361

64-
required = [
65-
'learning_rate', 'batch_size', 'epochs', 'dropout', \
66-
'optimizer', 'wv_len', \
67-
'filter_sizes', 'filter_sets', 'num_filters', 'emb_l2', 'w_l2']
62+
required = ['learning_rate', 'batch_size', 'epochs', 'drop', \
63+
'activation', 'out_activation', 'loss', 'optimizer', 'metrics', \
64+
'n_fold', 'scaling', 'initialization', 'shared_nnet_spec', \
65+
'ind_nnet_spec', 'feature_names']
6866

6967

7068

71-
class BenchmarkP3B3(candle.Benchmark):
69+
class BenchmarkP3B1(candle.Benchmark):
7270

7371
def set_locals(self):
7472
"""Functionality to set variables specific for the benchmark
@@ -79,7 +77,29 @@ def set_locals(self):
7977

8078
if required is not None:
8179
self.required = set(required)
82-
# if additional_definitions is not None:
83-
# self.additional_definitions = additional_definitions
80+
if additional_definitions is not None:
81+
self.additional_definitions = additional_definitions
8482

8583

84+
def build_data(nnet_spec_len, fold, data_path):
85+
""" Build feature sets to match the network topology
86+
"""
87+
X_train = []
88+
Y_train = []
89+
90+
X_test = []
91+
Y_test = []
92+
93+
for i in range( nnet_spec_len ):
94+
feature_train = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_train_feature.csv', delimiter= ',' )
95+
label_train = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_train_label.csv', delimiter= ',' )
96+
X_train.append( feature_train )
97+
Y_train.append( label_train )
98+
99+
feature_test = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_test_feature.csv', delimiter= ',' )
100+
label_test = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_test_label.csv', delimiter= ',' )
101+
X_test.append( feature_test )
102+
Y_test.append( label_test )
103+
104+
return X_train, Y_train, X_test, Y_test
105+

0 commit comments

Comments
 (0)