14
14
15
15
import candle_keras as candle
16
16
17
- '''
18
17
additional_definitions = [
19
18
{'name' :'train_features' ,
20
19
'action' :'store' ,
58
57
'action' :'store' ,
59
58
'type' :int }
60
59
]
61
- '''
62
60
63
61
64
- required = [
65
- 'learning_rate ' , 'batch_size ' , 'epochs ' , 'dropout ' , \
66
- 'optimizer ' , 'wv_len ' , \
67
- 'filter_sizes ' , 'filter_sets' , 'num_filters' , 'emb_l2' , 'w_l2 ' ]
62
+ required = ['learning_rate' , 'batch_size' , 'epochs' , 'drop' , \
63
+ 'activation ' , 'out_activation ' , 'loss ' , 'optimizer' , 'metrics ' , \
64
+ 'n_fold ' , 'scaling' , 'initialization' , 'shared_nnet_spec ' , \
65
+ 'ind_nnet_spec ' , 'feature_names ' ]
68
66
69
67
70
68
71
- class BenchmarkP3B3 (candle .Benchmark ):
69
+ class BenchmarkP3B1 (candle .Benchmark ):
72
70
73
71
def set_locals (self ):
74
72
"""Functionality to set variables specific for the benchmark
@@ -79,7 +77,29 @@ def set_locals(self):
79
77
80
78
if required is not None :
81
79
self .required = set (required )
82
- # if additional_definitions is not None:
83
- # self.additional_definitions = additional_definitions
80
+ if additional_definitions is not None :
81
+ self .additional_definitions = additional_definitions
84
82
85
83
84
+ def build_data (nnet_spec_len , fold , data_path ):
85
+ """ Build feature sets to match the network topology
86
+ """
87
+ X_train = []
88
+ Y_train = []
89
+
90
+ X_test = []
91
+ Y_test = []
92
+
93
+ for i in range ( nnet_spec_len ):
94
+ feature_train = np .genfromtxt (data_path + '/task' + str (i )+ '_' + str (fold )+ '_train_feature.csv' , delimiter = ',' )
95
+ label_train = np .genfromtxt (data_path + '/task' + str (i )+ '_' + str (fold )+ '_train_label.csv' , delimiter = ',' )
96
+ X_train .append ( feature_train )
97
+ Y_train .append ( label_train )
98
+
99
+ feature_test = np .genfromtxt (data_path + '/task' + str (i )+ '_' + str (fold )+ '_test_feature.csv' , delimiter = ',' )
100
+ label_test = np .genfromtxt (data_path + '/task' + str (i )+ '_' + str (fold )+ '_test_label.csv' , delimiter = ',' )
101
+ X_test .append ( feature_test )
102
+ Y_test .append ( label_test )
103
+
104
+ return X_train , Y_train , X_test , Y_test
105
+
0 commit comments