
Commit dd6b459

Added an inference method and the ability to download saved models from a server as well as to use local models. Fixed some cases of the model not respecting parameters.

1 parent 7018211, commit dd6b459

File tree

3 files changed: +103 -10 lines
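With the new flags, inference can be toggled without touching the code. Assuming the usual CANDLE keyword-style command line generated from the parameter definitions below (the exact flag spelling depends on the local setup, so treat this as illustrative), an invocation might look like:

python adrp_baseline_keras2.py --infer True --saved_model /path/to/reg_go.autosave.model.h5

When no saved_model is given, the model is fetched from model_url instead.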

examples/ADRP/adrp.py

Lines changed: 29 additions & 1 deletion

@@ -112,7 +112,24 @@
         "default": "ADRP_6W02_A_1_H",
         "help": "base name of pocket",
     },
-
+    {
+        "name": "saved_model",
+        "type": str,
+        "default": None,
+        "help": "Saved model to test",
+    },
+    {
+        "name": "model_url",
+        "type": str,
+        "default": None,
+        "help": "Url for saved models to test",
+    },
+    {
+        "name": "infer",
+        "type": candle.str2bool,
+        "default": False,
+        "help": "Flag to toggle inference mode",
+    },
 ]

 required = [
@@ -206,6 +223,17 @@ def load_headers(desc_headers, train_headers, header_url):
     return dh_dict, th_list


+def get_model(params):
+    url = params['model_url']
+    file_model = ('DIR.ml.' + params['base_name']
+                  + '.Orderable_zinc_db_enaHLL.sorted.4col.dd.parquet/'
+                  + 'reg_go.autosave.model.h5')
+    model_file = candle.get_file(
+        file_model, url + file_model, cache_subdir="Pilot1"
+    )
+    return model_file
+
+
 def load_data(params, seed):
     header_url = params["header_url"]
     dh_dict, th_list = load_headers('descriptor_headers.csv', 'training_headers.csv', header_url)
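get_model composes the remote file name from base_name and hands it to candle.get_file, which downloads on first use and reuses a local cache afterwards. Below is a minimal sketch of that cache-or-download behaviour, using only the Python standard library and assuming a Keras-style ~/.keras cache root; the real helper also handles things like checksums and archives, so this is illustrative only.

import os
import urllib.request

def fetch_cached(file_model, url, cache_subdir="Pilot1"):
    # Return a local path for file_model, downloading only on a cache miss.
    cache_dir = os.path.join(os.path.expanduser("~/.keras"), cache_subdir)
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, os.path.basename(file_model))
    if not os.path.exists(local_path):
        urllib.request.urlretrieve(url + file_model, local_path)
    return local_path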

examples/ADRP/adrp_baseline_keras2.py

Lines changed: 72 additions & 8 deletions

@@ -22,9 +22,10 @@
 from sklearn.metrics import (
     r2_score,
     roc_auc_score,
-    pearsonr,
+    # pearsonr,
     accuracy_score,
 )
+from scipy.stats import pearsonr

 import adrp
 import candle
@@ -179,6 +180,70 @@ def load_cache(cache_file):
     return x_train, y_train, x_val, y_val, x_test, y_test, x_labels, y_labels


+def run_inference(params):
+
+    if params['saved_model'] is not None:
+        model_file = params['saved_model']
+    else:
+        model_file = adrp.get_model(params)
+
+    print('Loading model from ', model_file)
+
+    # switch based on model type specified
+    if model_file.endswith('.json'):
+        # load json model + weights
+        base_model_file = model_file.split('.json')
+        # load json and create model
+        json_file = open(model_file, 'r')
+        loaded_model = json_file.read()
+        json_file.close()
+        loaded_model = model_from_json(loaded_model)
+
+        # load weights into new model
+        loaded_model.load_weights(base_model_file[0] + '.h5')
+        print("Loaded json model from disk")
+    elif model_file.endswith('.yaml'):
+        # load yaml model + weights
+        base_model_file = model_file.split('.yaml')
+        # load yaml and create model
+        yaml_file = open(model_file, 'r')
+        loaded_model = yaml_file.read()
+        yaml_file.close()
+        loaded_model = model_from_yaml(loaded_model)
+
+        # load weights into new model
+        loaded_model.load_weights(base_model_file[0] + '.h5')
+        print("Loaded yaml model from disk")
+    elif model_file.endswith('.h5'):
+        loaded_model = tf.keras.models.load_model(model_file, compile=False)
+        print("Loaded h5 model from disk")
+    else:
+        sys.exit("Model format should be one of json, yaml or h5")
+
+    # compile separately to get custom functions as needed
+    loaded_model.compile(optimizer=params['optimizer'], loss=params['loss'], metrics=['mae', r2])
+
+    # use same data as training
+    seed = params['rng_seed']
+    X_train, Y_train, X_test, Y_test, PS, count_array = adrp.load_data(params, seed)
+
+    print("X_train shape:", X_train.shape)
+    print("X_test shape:", X_test.shape)
+
+    print("Y_train shape:", Y_train.shape)
+    print("Y_test shape:", Y_test.shape)
+
+    score_train = loaded_model.evaluate(X_train, Y_train, verbose=0)
+
+    print("Training set loss:", score_train[0])
+    print("Training set mae:", score_train[1])
+
+    score_test = loaded_model.evaluate(X_test, Y_test, verbose=0)
+
+    print("Validation set loss:", score_test[0])
+    print("Validation set mae:", score_test[1])
+
+
 def run(params):
     args = candle.ArgumentStruct(**params)
     seed = args.rng_seed
@@ -451,9 +516,7 @@ def post_process(params, X_train, X_test, Y_test, score, history, model):
     print("Loaded json model from disk")

     # evaluate json loaded model on test data
-    loaded_model_json.compile(
-        loss="binary_crossentropy", optimizer="SGD", metrics=["mean_absolute_error"]
-    )
+    loaded_model.compile(optimizer=params['optimizer'], loss=params['loss'], metrics=['mae', r2])
     score_json = loaded_model_json.evaluate(X_test, Y_test, verbose=0)

     print("json Validation loss:", score_json[0])
@@ -466,9 +529,7 @@ def post_process(params, X_train, X_test, Y_test, score, history, model):
     print("Loaded yaml model from disk")

     # evaluate loaded model on test data
-    loaded_model_yaml.compile(
-        loss="binary_crossentropy", optimizer="SGD", metrics=["mean_absolute_error"]
-    )
+    loaded_model.compile(optimizer=params['optimizer'], loss=params['loss'], metrics=['mae', r2])
     score_yaml = loaded_model_yaml.evaluate(X_test, Y_test, verbose=0)

     print("yaml Validation loss:", score_yaml[0])
@@ -517,7 +578,10 @@ def post_process(params, X_train, X_test, Y_test, score, history, model):

 def main():
     params = initialize_parameters()
-    run(params)
+    if params['infer'] is True:
+        run_inference(params)
+    else:
+        run(params)


 if __name__ == "__main__":
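run_inference rebuilds the model by dispatching on the file extension, then compiles it with the configured optimizer and loss so the custom r2 metric is attached. A condensed sketch of the same dispatch follows, covering the json and h5 branches only (model_from_yaml has been removed from recent TensorFlow releases, so the yaml branch is omitted here); it assumes TensorFlow is installed and is not the benchmark's own code.

import sys
import tensorflow as tf

def load_saved_model(model_file):
    if model_file.endswith('.json'):
        # architecture-only format: weights live in a sibling .h5 file
        with open(model_file) as f:
            model = tf.keras.models.model_from_json(f.read())
        model.load_weights(model_file[:-len('.json')] + '.h5')
    elif model_file.endswith('.h5'):
        # full model: architecture and weights together; compile separately
        model = tf.keras.models.load_model(model_file, compile=False)
    else:
        sys.exit('Model format should be one of json or h5')
    return model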

examples/ADRP/adrp_default_model.txt

Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 [Global_Params]
 header_url = 'https://raw.githubusercontent.com/brettin/ML-training-inferencing/master/'
 data_url = 'ftp://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Examples/V5.1-1M-flatten/'
+model_url = 'ftp://ftp.mcs.anl.gov/pub/candle/public/models/examples/adrp/V5.1-ml-models-1M-flatten.release/'
 train_data = ''
 base_name = 'ADRP_6W02_A_1_H'
 model_name = 'adrp'
@@ -10,7 +11,7 @@ epochs = 400
 activation = 'elu'
 out_activation = 'relu'
 loss = 'mean_squared_error'
-optimizer = 'adam'
+optimizer = 'sgd'
 dropout = 0.1
 learning_rate = 0.0001
 momentum = 0.9
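The defaults file is plain INI, so the new model_url key can be sanity-checked with the standard library. The snippet below is illustrative only (the CANDLE framework does its own parsing); note that the file stores values with surrounding quotes, which a raw configparser read keeps.

import configparser

cfg = configparser.ConfigParser()
cfg.read('adrp_default_model.txt')
# values are quoted in the file, so strip the quotes after reading
model_url = cfg['Global_Params']['model_url'].strip("'")
print(model_url)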
