Commit e721eab: Support custom monitor (#66)
1 parent fefbbee

6 files changed: +37 -34 lines

tests/explorer/explorer_test.py
Lines changed: 1 addition & 2 deletions

@@ -12,7 +12,6 @@
     get_unittest_dataset_config,
 )
 from trinity.cli.launcher import explore
-from trinity.common.constants import MonitorType


 class BaseExplorerCase(RayUnittestBase):
@@ -23,7 +22,7 @@ def setUp(self):
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
         self.config.algorithm.repeat_times = 2
-        self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+        self.config.monitor.monitor_type = "tensorboard"
         self.config.project = "Trinity-unittest"
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2

tests/trainer/trainer_test.py
Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@
     get_unittest_dataset_config,
 )
 from trinity.cli.launcher import bench, both
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.common.constants import SyncMethod


 class BaseTrainerCase(RayUnittestBase):
@@ -30,7 +30,7 @@ def setUp(self):
         self.config.explorer.rollout_model.use_v1 = False
         self.config.project = "Trainer-unittest"
         self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+        self.config.monitor.monitor_type = "tensorboard"
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2
         self.config.synchronizer.sync_method = SyncMethod.NCCL

trinity/common/config.py
Lines changed: 3 additions & 2 deletions

@@ -8,7 +8,6 @@

 from trinity.common.constants import (
     AlgorithmType,
-    MonitorType,
     PromptType,
     ReadStrategy,
     StorageType,
@@ -278,7 +277,9 @@ class TrainerConfig:
 @dataclass
 class MonitorConfig:
     # TODO: support multiple monitors (List[MonitorType])
-    monitor_type: MonitorType = MonitorType.WANDB
+    monitor_type: str = "tensorboard"
+    # the default args for monitor
+    monitor_args: Dict = field(default_factory=dict)
     # ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor
     cache_dir: str = ""
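
The unit-test changes above show the intended usage of the new field: the monitor is now selected by a plain string that names a registered backend. A minimal sketch, assuming `config` is a trinity.common.config.Config instance obtained the same way the unit tests obtain theirs; note that this commit only defines monitor_args, so how its contents are consumed is an assumption:

    # `config` is assumed to be an existing Config, as in the unit tests above
    config.monitor.monitor_type = "tensorboard"  # any name registered in MONITOR, e.g. "wandb"
    config.monitor.monitor_args = {}  # defaults to an empty dict; no consumer is wired up in this commit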

trinity/explorer/explorer.py
Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
 from trinity.explorer.runner_pool import RunnerPool
 from trinity.manager.manager import CacheManager
 from trinity.utils.log import get_logger
-from trinity.utils.monitor import Monitor
+from trinity.utils.monitor import MONITOR


 @ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1})
@@ -47,7 +47,7 @@ def __init__(self, config: Config):
         for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
             self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer))
         self.runner_pool = self._init_runner_pool()
-        self.monitor = Monitor(
+        self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
             project=self.config.project,
             name=self.config.name,
             role="explorer",

trinity/trainer/verl_trainer.py
Lines changed: 2 additions & 2 deletions

@@ -34,7 +34,7 @@
     pprint,
     reduce_metrics,
 )
-from trinity.utils.monitor import Monitor
+from trinity.utils.monitor import MONITOR


 class _InternalDataLoader:
@@ -128,7 +128,7 @@ def __init__(
         self.algorithm_type = (
             AlgorithmType.PPO
         )  # TODO: initialize algorithm_type according to config
-        self.logger = Monitor(
+        self.logger = MONITOR.get(global_config.monitor.monitor_type)(
             project=config.trainer.project_name,
             name=config.trainer.experiment_name,
             role="trainer",

trinity/utils/monitor.py
Lines changed: 27 additions & 24 deletions

@@ -1,5 +1,7 @@
 """Monitor"""
+
 import os
+from abc import ABC, abstractmethod
 from typing import List, Optional, Union

 import numpy as np
@@ -8,11 +10,13 @@
 from torch.utils.tensorboard import SummaryWriter

 from trinity.common.config import Config
-from trinity.common.constants import MonitorType
 from trinity.utils.log import get_logger
+from trinity.utils.registry import Registry
+
+MONITOR = Registry("monitor")


-class Monitor:
+class Monitor(ABC):
     """Monitor"""

     def __init__(
@@ -22,15 +26,25 @@ def __init__(
         role: str,
         config: Config = None,  # pass the global Config for recording
     ) -> None:
-        if config.monitor.monitor_type == MonitorType.WANDB:
-            self.logger = WandbLogger(project, name, role, config)
-        elif config.monitor.monitor_type == MonitorType.TENSORBOARD:
-            self.logger = TensorboardLogger(project, name, role, config)
-        else:
-            raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}")
+        self.project = project
+        self.name = name
+        self.role = role
+        self.config = config

+    @abstractmethod
     def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
-        self.logger.log_table(table_name, experiences_table, step=step)
+        """Log a table"""
+
+    @abstractmethod
+    def log(self, data: dict, step: int, commit: bool = False) -> None:
+        """Log metrics."""
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the monitor"""
+
+    def __del__(self) -> None:
+        self.close()

     def calculate_metrics(
         self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
@@ -51,15 +65,9 @@ def calculate_metrics(
                 metrics[key] = val
         return metrics

-    def log(self, data: dict, step: int, commit: bool = False) -> None:
-        """Log metrics."""
-        self.logger.log(data, step=step, commit=commit)
-
-    def close(self) -> None:
-        self.logger.close()
-

-class TensorboardLogger:
+@MONITOR.register_module("tensorboard")
+class TensorboardMonitor(Monitor):
     def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
         self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard")
         os.makedirs(self.tensorboard_dir, exist_ok=True)
@@ -77,11 +85,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
     def close(self) -> None:
         self.logger.close()

-    def __del__(self) -> None:
-        self.logger.close()
-

-class WandbLogger:
+@MONITOR.register_module("wandb")
+class WandbMonitor(Monitor):
     def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
         self.logger = wandb.init(
             project=project,
@@ -104,6 +110,3 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:

     def close(self) -> None:
         self.logger.finish()
-
-    def __del__(self) -> None:
-        self.logger.finish()
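
With the registry in place, a project-specific monitor can be added without touching the framework: subclass Monitor, implement the three abstract methods, and register the class under a new name. A hedged sketch of that extension point; the "console" backend below is hypothetical and not part of this commit, while the Monitor base class, the register_module decorator, and the constructor signature are taken from the diff above:

    import pandas as pd

    from trinity.utils.monitor import MONITOR, Monitor


    @MONITOR.register_module("console")
    class ConsoleMonitor(Monitor):
        """Hypothetical monitor that writes everything to stdout."""

        def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
            # self.role is set by Monitor.__init__ in this commit
            print(f"[{self.role}] step={step} table={table_name}\n{experiences_table}")

        def log(self, data: dict, step: int, commit: bool = False) -> None:
            print(f"[{self.role}] step={step} metrics={data}")

        def close(self) -> None:
            pass  # nothing to flush for stdout

Setting config.monitor.monitor_type = "console" would then make both the explorer and the verl trainer pick this class up through MONITOR.get(), as shown in the diffs above.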
