Skip to content

Commit 11c03ba

Browse files
committed
add model saved hook
1 parent 0bc3c1e commit 11c03ba

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

trainer/src/trainer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,13 @@ class Trainer():
5454
def __init__(self, sync_dir=None, patch_size=572,
5555
max_workers=12,
5656
instruction_deleted_hook=None,
57-
segmentation_created_hook=None):
57+
segmentation_created_hook=None,
58+
model_saved_hook=None):
5859

5960
self.instruction_deleted_hook = instruction_deleted_hook
6061
self.segmentation_created_hook = segmentation_created_hook
62+
self.model_saved_hook = model_saved_hook
63+
6164

6265
valid_sizes = get_valid_patch_sizes()
6366
assert patch_size in valid_sizes, (f'Specified patch size of {patch_size}'
@@ -400,6 +403,10 @@ def validation(self):
400403
cur_metrics['f1'], prev_metrics['f1'])
401404
if was_saved:
402405
self.epochs_without_progress = 0
406+
if self.model_saved_hook:
407+
latest_model_path = model_utils.get_latest_model_paths(model_dir, 1)[0]
408+
self.model_saved_hook(latest_model_path)
409+
403410
else:
404411
self.epochs_without_progress += 1
405412

0 commit comments

Comments
 (0)