Skip to content

Commit a3c9fc6

Browse files
authored
Update HPO interface (#4035)
* update hpo interface * update unit test * update CHANGELOG.md
1 parent 91d2df2 commit a3c9fc6

File tree

4 files changed

+52
-11
lines changed

4 files changed

+52
-11
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ All notable changes to this project will be documented in this file.
6363
(<https://github.com/openvinotoolkit/training_extensions/pull/4011>)
6464
- Prevent using too low confidence thresholds in detection
6565
(<https://github.com/openvinotoolkit/training_extensions/pull/4018>)
66+
- Update HPO interface
67+
(<https://github.com/openvinotoolkit/training_extensions/pull/4035>)
6668

6769
### Bug fixes
6870

src/otx/core/config/hpo.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dataclasses import dataclass
99
from pathlib import Path # noqa: TCH003
10-
from typing import Any, Literal
10+
from typing import Any, Callable, Literal
1111

1212
import torch
1313

@@ -23,7 +23,12 @@
2323

2424
@dataclass
2525
class HpoConfig:
26-
"""DTO for HPO configuration."""
26+
"""DTO for HPO configuration.
27+
28+
progress_update_callback (Callable[[int | float], None] | None):
29+
callback to update progress. If it's given, it's called with progress every second.
30+
callbacks_to_exclude (list[str] | str | None): List of name of callbacks to exclude during HPO.
31+
"""
2732

2833
search_space: dict[str, dict[str, Any]] | str | Path | None = None
2934
save_path: str | None = None
@@ -40,3 +45,5 @@ class HpoConfig:
4045
asynchronous_sha: bool = num_workers > 1
4146
metric_name: str | None = None
4247
adapt_bs_search_space_max_val: Literal["None", "Safe", "Full"] = "None"
48+
progress_update_callback: Callable[[int | float], None] | None = None
49+
callbacks_to_exclude: list[str] | str | None = None

src/otx/engine/hpo/hpo_api.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
import json
1010
import logging
1111
import time
12+
from copy import copy
1213
from functools import partial
1314
from pathlib import Path
1415
from threading import Thread
1516
from typing import TYPE_CHECKING, Any, Callable, Literal
1617

1718
import torch
1819
import yaml
20+
from lightning import Callback
1921

2022
from otx.core.config.hpo import HpoConfig
2123
from otx.core.optimizer.callable import OptimizerCallableSupportHPO
@@ -35,7 +37,6 @@
3537
from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir, get_metric
3638

3739
if TYPE_CHECKING:
38-
from lightning import Callback
3940
from lightning.pytorch.cli import OptimizerCallable
4041

4142
from otx.engine.engine import Engine
@@ -48,7 +49,6 @@ def execute_hpo(
4849
engine: Engine,
4950
max_epochs: int,
5051
hpo_config: HpoConfig,
51-
progress_update_callback: Callable[[int | float], None] | None = None,
5252
callbacks: list[Callback] | Callback | None = None,
5353
**train_args,
5454
) -> tuple[dict[str, Any] | None, Path | None]:
@@ -58,8 +58,6 @@ def execute_hpo(
5858
engine (Engine): engine instnace.
5959
max_epochs (int): max epochs to train.
6060
hpo_config (HpoConfig): Configuration for HPO.
61-
progress_update_callback (Callable[[int | float], None] | None, optional):
62-
callback to update progress. If it's given, it's called with progress every second. Defaults to None.
6361
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
6462
6563
Returns:
@@ -97,8 +95,23 @@ def execute_hpo(
9795
logger.warning("HPO is skipped.")
9896
return None, None
9997

100-
if progress_update_callback is not None:
101-
Thread(target=_update_hpo_progress, args=[progress_update_callback, hpo_algo], daemon=True).start()
98+
if hpo_config.progress_update_callback is not None:
99+
Thread(target=_update_hpo_progress, args=[hpo_config.progress_update_callback, hpo_algo], daemon=True).start()
100+
101+
if hpo_config.callbacks_to_exclude is not None and callbacks is not None:
102+
if isinstance(hpo_config.callbacks_to_exclude, str):
103+
hpo_config.callbacks_to_exclude = [hpo_config.callbacks_to_exclude]
104+
if isinstance(callbacks, Callback):
105+
callbacks = [callbacks]
106+
107+
callbacks = copy(callbacks)
108+
callback_names = [callback.__class__.__name__ for callback in callbacks]
109+
callback_idx_to_exclude = [
110+
callback_names.index(cb_name) for cb_name in hpo_config.callbacks_to_exclude if cb_name in callback_names
111+
]
112+
sorted(callback_idx_to_exclude, reverse=True)
113+
for idx in callback_idx_to_exclude:
114+
callbacks.pop(idx)
102115

103116
run_hpo_loop(
104117
hpo_algo,

tests/unit/engine/hpo/test_hpo_api.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,27 @@ def mock_find_trial_file(mocker) -> MagicMock:
119119

120120
@pytest.fixture()
121121
def hpo_config() -> HpoConfig:
122-
return HpoConfig(metric_name="val/accuracy")
122+
return HpoConfig(metric_name="val/accuracy", callbacks_to_exclude="UselessCallback")
123123

124124

125125
@pytest.fixture()
126126
def mock_progress_update_callback() -> MagicMock:
127127
return MagicMock()
128128

129129

130+
class UsefullCallback:
131+
pass
132+
133+
134+
class UselessCallback:
135+
pass
136+
137+
138+
@pytest.fixture()
139+
def mock_callback() -> list:
140+
return [UsefullCallback(), UselessCallback()]
141+
142+
130143
def test_execute_hpo(
131144
mock_engine: MagicMock,
132145
hpo_config: HpoConfig,
@@ -138,12 +151,14 @@ def test_execute_hpo(
138151
mock_get_best_hpo_weight: MagicMock,
139152
mock_find_trial_file: MagicMock,
140153
mock_progress_update_callback: MagicMock,
154+
mock_callback: list,
141155
):
156+
hpo_config.progress_update_callback = mock_progress_update_callback
142157
best_config, best_hpo_weight = execute_hpo(
143158
engine=mock_engine,
144159
max_epochs=10,
145160
hpo_config=hpo_config,
146-
progress_update_callback=mock_progress_update_callback,
161+
callbacks=mock_callback,
147162
)
148163

149164
# check hpo workdir exists
@@ -152,12 +167,16 @@ def test_execute_hpo(
152167
# check a case where progress_update_callback exists
153168
mock_thread.assert_called_once()
154169
assert mock_thread.call_args.kwargs["target"] == _update_hpo_progress
155-
assert mock_thread.call_args.kwargs["args"][0] == mock_progress_update_callback
156170
assert mock_thread.call_args.kwargs["daemon"] is True
157171
mock_thread.return_value.start.assert_called_once()
158172
# check whether run_hpo_loop is called well
159173
mock_run_hpo_loop.assert_called_once()
160174
assert mock_run_hpo_loop.call_args.args[0] == mock_hpo_algo
175+
# check UselessCallback is excluded
176+
for callback in mock_run_hpo_loop.call_args.args[1].keywords["callbacks"]:
177+
assert not isinstance(callback, UselessCallback)
178+
# check origincal callback lists isn't changed.
179+
assert len(mock_callback) == 2
161180
# print_result is called after HPO is done
162181
mock_hpo_algo.print_result.assert_called_once()
163182
# best_config and best_hpo_weight are returned well

0 commit comments

Comments
 (0)