22
22
from sklearn .metrics import (
23
23
r2_score ,
24
24
roc_auc_score ,
25
- pearsonr ,
25
+ # pearsonr,
26
26
accuracy_score ,
27
27
)
28
+ from scipy .stats import pearsonr
28
29
29
30
import adrp
30
31
import candle
@@ -179,6 +180,70 @@ def load_cache(cache_file):
179
180
return x_train , y_train , x_val , y_val , x_test , y_test , x_labels , y_labels
180
181
181
182
183
def run_inference(params):
    """Load a previously saved model and evaluate it on the ADRP data splits.

    The model file is taken from ``params['saved_model']`` when given,
    otherwise fetched via ``adrp.get_model(params)``.  Three on-disk formats
    are supported, selected by file extension:

    * ``.json`` -- architecture as JSON, weights in a sibling ``.h5`` file
    * ``.yaml`` -- architecture as YAML, weights in a sibling ``.h5`` file
    * ``.h5``   -- full model saved with ``tf.keras`` (loaded uncompiled)

    Parameters
    ----------
    params : dict
        CANDLE-style parameter dict; keys used here: ``saved_model``,
        ``optimizer``, ``loss``, ``rng_seed``.

    Raises
    ------
    SystemExit
        If the model file extension is not one of json, yaml or h5.
    """
    if params['saved_model'] is not None:
        model_file = params['saved_model']
    else:
        model_file = adrp.get_model(params)

    print('Loading model from ', model_file)

    # switch based on model type specified
    if model_file.endswith('.json'):
        # load json architecture + weights from the sibling .h5 file
        base_model_file = model_file.split('.json')
        # context manager ensures the file is closed even if parsing raises
        with open(model_file, 'r') as json_file:
            loaded_model = model_from_json(json_file.read())

        # load weights into new model
        loaded_model.load_weights(base_model_file[0] + '.h5')
        print("Loaded json model from disk")
    elif model_file.endswith('.yaml'):
        # load yaml architecture + weights from the sibling .h5 file
        base_model_file = model_file.split('.yaml')
        with open(model_file, 'r') as yaml_file:
            loaded_model = model_from_yaml(yaml_file.read())

        # load weights into new model
        loaded_model.load_weights(base_model_file[0] + '.h5')
        print("Loaded yaml model from disk")
    elif model_file.endswith('.h5'):
        # full model; compile=False so we can recompile with our own metrics
        loaded_model = tf.keras.models.load_model(model_file, compile=False)
        print("Loaded h5 model from disk")
    else:
        sys.exit("Model format should be one of json, yaml or h5")

    # compile separately to get custom functions (r2 metric) as needed
    loaded_model.compile(
        optimizer=params['optimizer'], loss=params['loss'], metrics=['mae', r2]
    )

    # use same data (and same RNG seed, hence same split) as training
    seed = params['rng_seed']
    X_train, Y_train, X_test, Y_test, PS, count_array = adrp.load_data(params, seed)

    print("X_train shape:", X_train.shape)
    print("X_test shape:", X_test.shape)

    print("Y_train shape:", Y_train.shape)
    print("Y_test shape:", Y_test.shape)

    score_train = loaded_model.evaluate(X_train, Y_train, verbose=0)

    print("Training set loss:", score_train[0])
    print("Training set mae:", score_train[1])

    # NOTE(review): this evaluates X_test but the messages say "Validation
    # set" -- presumably adrp.load_data's "test" split is used as validation
    # during training; confirm against adrp.load_data before renaming.
    score_test = loaded_model.evaluate(X_test, Y_test, verbose=0)

    print("Validation set loss:", score_test[0])
    print("Validation set mae:", score_test[1])
246
+
182
247
def run (params ):
183
248
args = candle .ArgumentStruct (** params )
184
249
seed = args .rng_seed
@@ -451,9 +516,7 @@ def post_process(params, X_train, X_test, Y_test, score, history, model):
451
516
print ("Loaded json model from disk" )
452
517
453
518
# evaluate json loaded model on test data
454
- loaded_model_json .compile (
455
- loss = "binary_crossentropy" , optimizer = "SGD" , metrics = ["mean_absolute_error" ]
456
- )
519
+ loaded_model .compile (optimizer = params ['optimizer' ], loss = params ['loss' ], metrics = ['mae' , r2 ])
457
520
score_json = loaded_model_json .evaluate (X_test , Y_test , verbose = 0 )
458
521
459
522
print ("json Validation loss:" , score_json [0 ])
@@ -466,9 +529,7 @@ def post_process(params, X_train, X_test, Y_test, score, history, model):
466
529
print ("Loaded yaml model from disk" )
467
530
468
531
# evaluate loaded model on test data
469
- loaded_model_yaml .compile (
470
- loss = "binary_crossentropy" , optimizer = "SGD" , metrics = ["mean_absolute_error" ]
471
- )
532
+ loaded_model .compile (optimizer = params ['optimizer' ], loss = params ['loss' ], metrics = ['mae' , r2 ])
472
533
score_yaml = loaded_model_yaml .evaluate (X_test , Y_test , verbose = 0 )
473
534
474
535
print ("yaml Validation loss:" , score_yaml [0 ])
@@ -517,7 +578,10 @@ def post_process(params, X_train, X_test, Y_test, score, history, model):
517
578
518
579
def main():
    """Entry point: build parameters, then run inference or training.

    Dispatches to ``run_inference`` when the ``infer`` flag is set,
    otherwise runs the normal training path via ``run``.
    """
    params = initialize_parameters()
    # Truthiness test instead of `is True`: the flag may arrive as 1 / "True"
    # from a config parser, and identity comparison would silently fall
    # through to training in that case.
    if params['infer']:
        run_inference(params)
    else:
        run(params)
521
585
522
586
523
587
if __name__ == "__main__" :
0 commit comments