Skip to content

Commit 128154f

Browse files
authored
[FIX] Handle unpickable update_progress_callback (#1892)
* fix: get rid of unpikcable object * fix: separte pickle and deepcopy * fix: align type hint * test: fix gt * fix: consider inherited classes
1 parent bcd3109 commit 128154f

File tree

3 files changed

+39
-7
lines changed

3 files changed

+39
-7
lines changed

otx/api/entities/train_parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class UpdateProgressCallback(Protocol):
1414
`score: Optional[float] = None`
1515
"""
1616

17-
def __call__(self, progress: int, score: Optional[float] = None):
17+
def __call__(self, progress: float, score: Optional[float] = None):
1818
"""Callback to provide updates about the progress of a task.
1919
2020
It is recommended to call this function at least once per epoch.
@@ -30,7 +30,7 @@ def __call__(self, progress: int, score: Optional[float] = None):
3030

3131

3232
# pylint: disable=unused-argument
33-
def default_progress_callback(progress: int, score: Optional[float] = None):
33+
def default_progress_callback(progress: float, score: Optional[float] = None):
3434
"""Default progress callback. It is a placeholder (does nothing) and is used in empty TrainParameters."""
3535

3636

otx/api/usecases/reporting/time_monitor_callback.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@
99
import logging
1010
import math
1111
import time
12-
from typing import List, Optional
12+
from copy import deepcopy
13+
from typing import List
1314

14-
from otx.api.entities.train_parameters import UpdateProgressCallback
15+
import dill
16+
17+
from otx.api.entities.train_parameters import (
18+
UpdateProgressCallback,
19+
default_progress_callback,
20+
)
1521
from otx.api.usecases.reporting.callback import Callback
1622

1723
logger = logging.getLogger(__name__)
@@ -27,7 +33,7 @@ class TimeMonitorCallback(Callback):
2733
num_test_steps (int): amount of testing steps
2834
epoch_history (int): Amount of previous epochs to calculate average epoch time over
2935
step_history (int): Amount of previous steps to calculate average steps time over
30-
update_progress_callback (Optional[UpdateProgressCallback]): Callback to update progress
36+
update_progress_callback (UpdateProgressCallback): Callback to update progress
3137
"""
3238

3339
def __init__(
@@ -38,7 +44,7 @@ def __init__(
3844
num_test_steps: int = 0,
3945
epoch_history: int = 5,
4046
step_history: int = 50,
41-
update_progress_callback: Optional[UpdateProgressCallback] = None,
47+
update_progress_callback: UpdateProgressCallback = default_progress_callback,
4248
):
4349

4450
self.total_epochs = num_epoch
@@ -67,6 +73,32 @@ def __init__(
6773

6874
self.update_progress_callback = update_progress_callback
6975

76+
def __getstate__(self):
77+
"""Return state values to be pickled."""
78+
state = self.__dict__.copy()
79+
# update_progress_callback is not always pickable object
80+
# if it is not, replace it with default callback
81+
if not dill.pickles(state["update_progress_callback"]):
82+
state["update_progress_callback"] = default_progress_callback
83+
return state
84+
85+
def __deepcopy__(self, memo):
86+
"""Return deepcopy object."""
87+
88+
update_progress_callback = self.update_progress_callback
89+
self.update_progress_callback = None
90+
self.__dict__["__deepcopy__"] = None
91+
92+
result = deepcopy(self, memo)
93+
94+
self.__dict__.pop("__deepcopy__")
95+
result.__dict__.pop("__deepcopy__")
96+
result.update_progress_callback = update_progress_callback
97+
self.update_progress_callback = update_progress_callback
98+
99+
memo[id(self)] = result
100+
return result
101+
70102
def on_train_batch_begin(self, batch, logs=None):
71103
"""Set the value of current step and start the timer."""
72104
self.current_step += 1

tests/unit/api/usecases/reporting/test_time_monitor_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def check_time_monitor_callback_attributes(
9393
actual_time_monitor_callback=time_monitor_callback,
9494
expected_epoch_history=5,
9595
expected_step_history=50,
96-
expected_update_progress_callback=None,
96+
expected_update_progress_callback=default_progress_callback,
9797
)
9898
# Checking attributes of "TimeMonitorCallback" initialized with specified optional parameters
9999
step_history = 10

0 commit comments

Comments
 (0)