Skip to content

Commit f0d17aa

Browse files
authored
[Feature] Support trial name template (#146)
1 parent d95ddb2 commit f0d17aa

File tree

4 files changed

+52
-2
lines changed

4 files changed

+52
-2
lines changed

siatune/tune/tuner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from siatune.codebase import build_task
1414
from siatune.tune import (build_callback, build_scheduler, build_searcher,
1515
build_space, build_stopper)
16+
from .utils import NAME_CREATOR
1617

1718

1819
class Tuner:
@@ -106,8 +107,10 @@ def __init__(
106107
tune_config=TuneConfig(
107108
search_alg=searcher,
108109
scheduler=trial_scheduler,
109-
trial_name_creator=lambda trial: trial.trial_id,
110-
trial_dirname_creator=lambda trial: trial.experiment_tag,
110+
trial_name_creator=NAME_CREATOR.get(
111+
tune_cfg.pop('trial_name_creator', 'trial_id')),
112+
trial_dirname_creator=NAME_CREATOR.get(
113+
tune_cfg.pop('trial_dirname_creator', 'experiment_tag')),
111114
**tune_cfg),
112115
run_config=RunConfig(
113116
name=self.experiment_name,

siatune/tune/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) SI-Analytics. All rights reserved.
2+
from .name_creator import NAME_CREATOR, experiment_tag, trial_id
3+
4+
__all__ = ['NAME_CREATOR', 'experiment_tag', 'trial_id']

siatune/tune/utils/name_creator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) SI-Analytics. All rights reserved.
2+
from mmengine.registry import Registry
3+
from ray.tune.experiment import Trial
4+
5+
NAME_CREATOR = Registry('name creator')
6+
7+
8+
@NAME_CREATOR.register_module()
9+
def trial_id(trial: Trial) -> str:
10+
return trial.trial_id
11+
12+
13+
@NAME_CREATOR.register_module()
14+
def experiment_tag(trial: Trial) -> str:
15+
return trial.experiment_tag

tests/test_tune/test_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import inspect
2+
from unittest.mock import MagicMock
3+
4+
import pytest
5+
6+
from siatune.tune.utils import NAME_CREATOR
7+
8+
9+
@pytest.fixture
10+
def trial():
11+
return MagicMock(
12+
trainable_name='test',
13+
trial_id='trial_id',
14+
experiment_tag='experiment_tag')
15+
16+
17+
def test_trial_id(trial):
18+
tmpl = NAME_CREATOR.get('trial_id')
19+
assert inspect.isfunction(tmpl)
20+
assert tmpl.__name__ == 'trial_id'
21+
assert tmpl(trial) == trial.trial_id
22+
23+
24+
def test_experiment_tag(trial):
25+
tmpl = NAME_CREATOR.get('experiment_tag')
26+
assert inspect.isfunction(tmpl)
27+
assert tmpl.__name__ == 'experiment_tag'
28+
assert tmpl(trial) == trial.experiment_tag

0 commit comments

Comments
 (0)