
Commit 47dae8b

Update simple_cerebros_random_search.py
Made model purging more flexible to accommodate preservation of the best model.
1 parent fcc2efd commit 47dae8b

1 file changed

cerebros/simplecerebrosrandomsearch/simple_cerebros_random_search.py

Lines changed: 54 additions & 5 deletions
@@ -600,13 +600,62 @@ def has_valid_metric(num):
         return best

     def purge_model_storage(self):
-        path_0 = f"{self.project_name}/models"
-        rmtree(path_0)
+        """Deletes all cached models (wipes the model cache clean).
+
+        Recommended when running in a container without a mounted volume.
+        It is recommended to use an artifact registry to accession the best model.
+        """
+        model_cache_path = f"{self.project_name}/models"
+        rmtree(model_cache_path)

-    def get_best_model(self, purge_model_storage_files: bool=False):
+
+    def purge_models_except_best_model(self):
+        """Deletes all cached models except the best model found.
+
+        Recommended when running in a container without a mounted volume and
+        building models that take considerable time to reproduce. It is
+        recommended to use an artifact registry to accession the best model,
+        but this will preserve a redundant copy in case accessioning it to a
+        registry is unsuccessful.
+        """
+        if not self.best_model_path:
+            raise ValueError(
+                "purge_models_except_best_model was called prematurely: "
+                "self.best_model_path is not set, meaning there is no best model yet.")
+        model_cache_path = f"{self.project_name}/models"
+        files_path_obj = os.listdir(model_cache_path)
+        files_str = [str(p) for p in files_path_obj]
+        print("Files in model cache:")
+        for file in files_str:
+            file_path = f"{model_cache_path}/{file}"
+            print(f"    {file_path}")
+            if file_path != self.best_model_path:
+                print(f"Removing: {file_path}")
+                os.remove(file_path)
+            # Temp debug code:
+            else:
+                print(f"Not removing {file_path}")
+
+    def get_best_model(self, purge_model_storage_files=0) -> tf.keras.Model:
+        """Returns the best model from this meta-trial and optionally purges
+        the cache of models stored on disk.
+
+        Params:
+        - purge_model_storage_files: Union[str, int]
+            - 0: does not purge the cached models; just returns the best model.
+            - 1: purges all cached models except the best model found.
+            - "slate": removes all cached models, the best or otherwise.
+
+        When running ephemeral trials in a container without a mounted volume
+        (to prevent memory pressure accumulating from ephemeral files held in
+        memory), or when otherwise working with hard disk space limitations,
+        we recommend setting this to:
+        - "slate": if your models are quick to reproduce and an accidental model
+          loss is not problematic as long as you have the parameters to
+          reproduce it approximately.
+        - 1: if your models take considerable time to reproduce, or if a small
+          performance difference from another model built from the same
+          parameters is problematic.
+        - 0: if you have ample disk space and are not in a container, or are in
+          one with a suitable mounted volume.
+        """
         best_model = tf.keras.models.load_model(self.best_model_path)
-        if purge_model_storage_files:
-            self.purge_model_storage()
+        if purge_model_storage_files == 1:
+            self.purge_models_except_best_model()
+        elif purge_model_storage_files == "slate":
+            self.purge_model_storage()
+        elif purge_model_storage_files == 0:
+            pass
+        else:
+            raise ValueError(
+                "The parameter purge_model_storage_files in get_best_model() accepts "
+                "3 values: 0 (don't purge), 1 (purge all but the best model), and "
+                "'slate' (remove all cached models).")
         return best_model

     # ->
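
For context, a minimal usage sketch of the new argument follows. It assumes the enclosing class is SimpleCerebrosRandomSearch and that an instance (here given the hypothetical name cerebros_automl) has already completed its trials, so best_model_path is populated; the constructor arguments and the search call are omitted.

# Assumption: `cerebros_automl` is an already-trained SimpleCerebrosRandomSearch
# instance, so its best_model_path attribute points at a saved model on disk.

# 0 (default): return the best model and leave the on-disk model cache intact.
best_model = cerebros_automl.get_best_model(purge_model_storage_files=0)

# 1: return the best model, then delete every cached model except the best one.
best_model = cerebros_automl.get_best_model(purge_model_storage_files=1)

# "slate": return the best model, then delete the entire model cache,
# including the saved copy of the best model.
best_model = cerebros_automl.get_best_model(purge_model_storage_files="slate")

# The returned object is an in-memory tf.keras.Model, so it can still be
# persisted elsewhere (e.g. accessioned to an artifact registry) even after
# a "slate" purge; the path below is only an example.
best_model.save("best_model.keras")

Any other value for purge_model_storage_files raises a ValueError, per the validation branch added in this commit.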
