Skip to content

Commit 516ee5f

Browse files
authored
Update p3b1.py
1 parent 1b83853 commit 516ee5f

File tree

1 file changed

+9
-29
lines changed

1 file changed

+9
-29
lines changed

Pilot3/P3B1/p3b1.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import candle_keras as candle
1616

17+
'''
1718
additional_definitions = [
1819
{'name':'train_features',
1920
'action':'store',
@@ -57,16 +58,17 @@
5758
'action':'store',
5859
'type':int}
5960
]
61+
'''
6062

6163

62-
required = ['learning_rate', 'batch_size', 'epochs', 'drop', \
63-
'activation', 'out_activation', 'loss', 'optimizer', 'metrics', \
64-
'n_fold', 'scaling', 'initialization', 'shared_nnet_spec', \
65-
'ind_nnet_spec', 'feature_names']
64+
required = [
65+
'learning_rate', 'batch_size', 'epochs', 'dropout', \
66+
'optimizer', 'wv_len', \
67+
'filter_sizes', 'filter_sets', 'num_filters', 'emb_l2', 'w_l2']
6668

6769

6870

69-
class BenchmarkP3B1(candle.Benchmark):
71+
class BenchmarkP3B3(candle.Benchmark):
7072

7173
def set_locals(self):
7274
"""Functionality to set variables specific for the benchmark
@@ -77,29 +79,7 @@ def set_locals(self):
7779

7880
if required is not None:
7981
self.required = set(required)
80-
if additional_definitions is not None:
81-
self.additional_definitions = additional_definitions
82+
# if additional_definitions is not None:
83+
# self.additional_definitions = additional_definitions
8284

8385

84-
def build_data(nnet_spec_len, fold, data_path):
85-
""" Build feature sets to match the network topology
86-
"""
87-
X_train = []
88-
Y_train = []
89-
90-
X_test = []
91-
Y_test = []
92-
93-
for i in range( nnet_spec_len ):
94-
feature_train = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_train_feature.csv', delimiter= ',' )
95-
label_train = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_train_label.csv', delimiter= ',' )
96-
X_train.append( feature_train )
97-
Y_train.append( label_train )
98-
99-
feature_test = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_test_feature.csv', delimiter= ',' )
100-
label_test = np.genfromtxt(data_path + '/task'+str(i)+'_'+str(fold)+'_test_label.csv', delimiter= ',' )
101-
X_test.append( feature_test )
102-
Y_test.append( label_test )
103-
104-
return X_train, Y_train, X_test, Y_test
105-

0 commit comments

Comments
 (0)