|
10 | 10 | from cerebros.denseautomlstructuralcomponent.dense_automl_structural_component\ |
11 | 11 | import zero_7_exp_decay, zero_95_exp_decay, simple_sigmoid |
12 | 12 | from ast import literal_eval |
| 13 | +from os import listdir |
13 | 14 |
|
14 | 15 | NUMBER_OF_TRAILS_PER_BATCH = 2 |
15 | 16 | NUMBER_OF_BATCHES_OF_TRIALS = 2 |
16 | 17 |
|
| 18 | +META_TRIAL_NUMBER = 1 |
| 19 | + |
17 | 20 | ### |
18 | 21 |
|
19 | 22 | LABEL_COLUMN = 'price' |
|
23 | 26 | .replace('T', '_')\ |
24 | 27 | .replace(':', '_')\ |
25 | 28 | .replace('-', '_') |
26 | | -PROJECT_NAME = f'{TIME}_cerebros_auto_ml_test' |
| 29 | +PROJECT_NAME = f'{TIME}_cerebros_auto_ml_test-{META_TRIAL_NUMBER }' |
| 30 | +PROJECT_NAME = f"{PROJECT_NAME}_meta_{meta_trial_number}" |
27 | 31 |
|
28 | 32 | def hash_a_row(row): |
29 | 33 | """casts a row of a Pandas DataFrame as a String, hashes it, and casts it |
@@ -207,16 +211,22 @@ def hash_based_split(df, # Pandas dataframe |
207 | 211 | metrics=[tf.keras.metrics.RootMeanSquaredError()], |
208 | 212 | epochs=epochs, |
209 | 213 | patience=7, |
210 | | - project_name=f"{PROJECT_NAME}_meta_{meta_trial_number}", |
| 214 | + project_name=PROJECT_NAME, |
211 | 215 | # use_multiprocessing_for_multiple_neural_networks=False, # pull this param |
212 | 216 | model_graphs='model_graphs', |
213 | 217 | batch_size=batch_size, |
214 | 218 | meta_trial_number=meta_trial_number) |
215 | 219 | result = cerebros.run_random_search() |
216 | 220 |
|
217 | 221 | print("Best model: (May need to re-initialize weights, and retrain with early stopping callback)") |
218 | | -best_model_found = cerebros.get_best_model() |
| 222 | +best_model_found = cerebros.get_best_model(purge_model_storage_files=True) |
219 | 223 | print(best_model_found.summary()) |
220 | 224 |
|
| 225 | +# Verify purge_model_storage_files works: |
| 226 | +model_storage_path = f"{PROJECT_NAME}/models" |
| 227 | +num_items = len(listdir(model_storage_path)) |
| 228 | +print(f"There are {num_items} items in {model_storage_path}") |
| 229 | +assert num_items == 0 |
| 230 | + |
221 | 231 | print("result extracted from cerebros") |
222 | 232 | print(f"Final result was (val_root_mean_squared_error): {result}") |
0 commit comments