99import logging
1010import math
1111import time
12- from typing import List , Optional
12+ from copy import deepcopy
13+ from typing import List
1314
14- from otx .api .entities .train_parameters import UpdateProgressCallback
15+ import dill
16+
17+ from otx .api .entities .train_parameters import (
18+ UpdateProgressCallback ,
19+ default_progress_callback ,
20+ )
1521from otx .api .usecases .reporting .callback import Callback
1622
1723logger = logging .getLogger (__name__ )
@@ -27,7 +33,7 @@ class TimeMonitorCallback(Callback):
2733 num_test_steps (int): amount of testing steps
2834 epoch_history (int): Amount of previous epochs to calculate average epoch time over
2935 step_history (int): Amount of previous steps to calculate average steps time over
30- update_progress_callback (Optional[ UpdateProgressCallback] ): Callback to update progress
36+ update_progress_callback (UpdateProgressCallback): Callback to update progress
3137 """
3238
3339 def __init__ (
@@ -38,7 +44,7 @@ def __init__(
3844 num_test_steps : int = 0 ,
3945 epoch_history : int = 5 ,
4046 step_history : int = 50 ,
41- update_progress_callback : Optional [ UpdateProgressCallback ] = None ,
47+ update_progress_callback : UpdateProgressCallback = default_progress_callback ,
4248 ):
4349
4450 self .total_epochs = num_epoch
@@ -67,6 +73,32 @@ def __init__(
6773
6874 self .update_progress_callback = update_progress_callback
6975
76+ def __getstate__ (self ):
77+ """Return state values to be pickled."""
78+ state = self .__dict__ .copy ()
79+ # update_progress_callback is not always pickable object
80+ # if it is not, replace it with default callback
81+ if not dill .pickles (state ["update_progress_callback" ]):
82+ state ["update_progress_callback" ] = default_progress_callback
83+ return state
84+
85+ def __deepcopy__ (self , memo ):
86+ """Return deepcopy object."""
87+
88+ update_progress_callback = self .update_progress_callback
89+ self .update_progress_callback = None
90+ self .__dict__ ["__deepcopy__" ] = None
91+
92+ result = deepcopy (self , memo )
93+
94+ self .__dict__ .pop ("__deepcopy__" )
95+ result .__dict__ .pop ("__deepcopy__" )
96+ result .update_progress_callback = update_progress_callback
97+ self .update_progress_callback = update_progress_callback
98+
99+ memo [id (self )] = result
100+ return result
101+
70102 def on_train_batch_begin (self , batch , logs = None ):
71103 """Set the value of current step and start the timer."""
72104 self .current_step += 1
0 commit comments