Skip to content

Commit 3f9bd0c

Browse files
Update regression-example-ames-no-preproc.py
Add negative case test for purge_model_storage.
1 parent d774374 commit 3f9bd0c

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

regression-example-ames-no-preproc.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from cerebros.denseautomlstructuralcomponent.dense_automl_structural_component\
1111
import zero_7_exp_decay, zero_95_exp_decay, simple_sigmoid
1212
from ast import literal_eval
13+
from os import listdir
14+
from os.path import exists
15+
1316

1417
NUMBER_OF_TRAILS_PER_BATCH = 2
1518
NUMBER_OF_BATCHES_OF_TRIALS = 2
@@ -20,15 +23,15 @@
2023

2124
## your data:
2225

26+
META_TRIAL_NUMBER = 1
2327

2428
TIME = pendulum.now().__str__()[:16]\
2529
.replace('T', '_')\
2630
.replace(':', '_')\
2731
.replace('-', '_')
28-
PROJECT_NAME = f'{TIME}_cerebros_auto_ml_test'
29-
32+
PROJECT_NAME = f"{TIME}_cerebros_auto_ml_test"
33+
PROJECT_NAME = f"{PROJECT_NAME}_meta_{META_TRIAL_NUMBER}"
3034

31-
# white = pd.read_csv('wine_data.csv')
3235

3336
raw_data = pd.read_csv('ames.csv')
3437
needed_cols = [
@@ -110,7 +113,7 @@
110113
metrics=[tf.keras.metrics.RootMeanSquaredError()],
111114
epochs=epochs,
112115
patience=7,
113-
project_name=f"{PROJECT_NAME}_meta_{meta_trial_number}",
116+
project_name=PROJECT_NAME,
114117
# use_multiprocessing_for_multiple_neural_networks=False, # pull this param
115118
model_graphs='model_graphs',
116119
batch_size=batch_size,
@@ -121,5 +124,12 @@
121124
best_model_found = cerebros.get_best_model()
122125
print(best_model_found.summary())
123126

127+
# Validate that purge_model_storage is NOT active by default
128+
model_storage_path = f"{PROJECT_MAME}/models"
129+
assert exists(model_storage_path)
130+
num_items = len(listdir(model_storage_path))
131+
print(f"There are {num_items} in {model_storage_path}")
132+
assert num_items > 0
133+
124134
print("result extracted from cerebros")
125135
print(f"Final result was (val_root_mean_squared_error): {result}")

0 commit comments

Comments
 (0)