File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed
Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -54,10 +54,13 @@ class Trainer():
5454 def __init__ (self , sync_dir = None , patch_size = 572 ,
5555 max_workers = 12 ,
5656 instruction_deleted_hook = None ,
57- segmentation_created_hook = None ):
57+ segmentation_created_hook = None ,
58+ model_saved_hook = None ):
5859
5960 self .instruction_deleted_hook = instruction_deleted_hook
6061 self .segmentation_created_hook = segmentation_created_hook
62+ self .model_saved_hook = model_saved_hook
63+
6164
6265 valid_sizes = get_valid_patch_sizes ()
6366 assert patch_size in valid_sizes , (f'Specified patch size of { patch_size } '
@@ -400,6 +403,10 @@ def validation(self):
400403 cur_metrics ['f1' ], prev_metrics ['f1' ])
401404 if was_saved :
402405 self .epochs_without_progress = 0
406+ if self .model_saved_hook :
407+ latest_model_path = model_utils .get_latest_model_paths (model_dir , 1 )[0 ]
408+ self .model_saved_hook (latest_model_path )
409+
403410 else :
404411 self .epochs_without_progress += 1
405412
You can’t perform that action at this time.
0 commit comments