diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 74b5d400e5..b0f354c4fa 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -12,7 +12,6 @@
     get_unittest_dataset_config,
 )
 from trinity.cli.launcher import explore
-from trinity.common.constants import MonitorType
 
 
 class BaseExplorerCase(RayUnittestBase):
@@ -23,7 +22,7 @@ def setUp(self):
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
         self.config.algorithm.repeat_times = 2
-        self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+        self.config.monitor.monitor_type = "tensorboard"
         self.config.project = "Trinity-unittest"
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index ac73e46c8d..24e6730a1e 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -15,7 +15,7 @@
     get_unittest_dataset_config,
 )
 from trinity.cli.launcher import bench, both
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.common.constants import SyncMethod
 
 
 class BaseTrainerCase(RayUnittestBase):
@@ -30,7 +30,7 @@ def setUp(self):
         self.config.explorer.rollout_model.use_v1 = False
         self.config.project = "Trainer-unittest"
         self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+        self.config.monitor.monitor_type = "tensorboard"
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2
         self.config.synchronizer.sync_method = SyncMethod.NCCL
diff --git a/trinity/common/config.py b/trinity/common/config.py
index e0660ab03a..3feec8e2ea 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -8,7 +8,6 @@
 
 from trinity.common.constants import (
     AlgorithmType,
-    MonitorType,
     PromptType,
     ReadStrategy,
     StorageType,
@@ -278,7 +277,9 @@ class TrainerConfig:
 @dataclass
 class MonitorConfig:
     # TODO: support multiple monitors (List[MonitorType])
-    monitor_type: MonitorType = MonitorType.WANDB
+    monitor_type: str = "tensorboard"
+    # the default args for monitor
+    monitor_args: Dict = field(default_factory=dict)
     # ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor
     cache_dir: str = ""
 
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 9c3cc414c7..bdc0228a65 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -20,7 +20,7 @@
 from trinity.explorer.runner_pool import RunnerPool
 from trinity.manager.manager import CacheManager
 from trinity.utils.log import get_logger
-from trinity.utils.monitor import Monitor
+from trinity.utils.monitor import MONITOR
 
 
 @ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1})
@@ -47,7 +47,7 @@ def __init__(self, config: Config):
         for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
             self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer))
         self.runner_pool = self._init_runner_pool()
-        self.monitor = Monitor(
+        self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
             project=self.config.project,
             name=self.config.name,
             role="explorer",
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 7590d6075b..1041600d87 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -34,7 +34,7 @@
     pprint,
     reduce_metrics,
 )
-from trinity.utils.monitor import Monitor
+from trinity.utils.monitor import MONITOR
 
 
 class _InternalDataLoader:
@@ -128,7 +128,7 @@ def __init__(
         self.algorithm_type = (
             AlgorithmType.PPO
         )  # TODO: initialize algorithm_type according to config
-        self.logger = Monitor(
+        self.logger = MONITOR.get(global_config.monitor.monitor_type)(
             project=config.trainer.project_name,
             name=config.trainer.experiment_name,
             role="trainer",
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 3044c6dcc8..f12a854335 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -1,5 +1,7 @@
 """Monitor"""
+
 import os
+from abc import ABC, abstractmethod
 from typing import List, Optional, Union
 
 import numpy as np
@@ -8,11 +10,13 @@
 from torch.utils.tensorboard import SummaryWriter
 
 from trinity.common.config import Config
-from trinity.common.constants import MonitorType
 from trinity.utils.log import get_logger
+from trinity.utils.registry import Registry
+
+MONITOR = Registry("monitor")
 
 
-class Monitor:
+class Monitor(ABC):
     """Monitor"""
 
     def __init__(
@@ -22,15 +26,25 @@ def __init__(
         role: str,
         config: Config = None,  # pass the global Config for recording
     ) -> None:
-        if config.monitor.monitor_type == MonitorType.WANDB:
-            self.logger = WandbLogger(project, name, role, config)
-        elif config.monitor.monitor_type == MonitorType.TENSORBOARD:
-            self.logger = TensorboardLogger(project, name, role, config)
-        else:
-            raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}")
+        self.project = project
+        self.name = name
+        self.role = role
+        self.config = config
 
+    @abstractmethod
     def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
-        self.logger.log_table(table_name, experiences_table, step=step)
+        """Log a table"""
+
+    @abstractmethod
+    def log(self, data: dict, step: int, commit: bool = False) -> None:
+        """Log metrics."""
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the monitor"""
+
+    def __del__(self) -> None:
+        self.close()
 
     def calculate_metrics(
         self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
@@ -51,15 +65,9 @@ def calculate_metrics(
                 metrics[key] = val
         return metrics
 
-    def log(self, data: dict, step: int, commit: bool = False) -> None:
-        """Log metrics."""
-        self.logger.log(data, step=step, commit=commit)
-
-    def close(self) -> None:
-        self.logger.close()
-
 
-class TensorboardLogger:
+@MONITOR.register_module("tensorboard")
+class TensorboardMonitor(Monitor):
     def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
         self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard")
         os.makedirs(self.tensorboard_dir, exist_ok=True)
@@ -77,11 +85,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
     def close(self) -> None:
         self.logger.close()
 
-    def __del__(self) -> None:
-        self.logger.close()
-
 
-class WandbLogger:
+@MONITOR.register_module("wandb")
+class WandbMonitor(Monitor):
     def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
         self.logger = wandb.init(
             project=project,
@@ -104,6 +110,3 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
 
     def close(self) -> None:
         self.logger.finish()
-
-    def __del__(self) -> None:
-        self.logger.finish()
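
Usage note (not part of the patch): with the MONITOR registry in place, a new backend only needs to subclass Monitor, implement the three abstract methods, and register itself under a string key that config.monitor.monitor_type can then select. The sketch below is illustrative only; the "console" key and ConsoleMonitor class are hypothetical and do not exist in the repository.

from trinity.utils.monitor import MONITOR, Monitor


@MONITOR.register_module("console")  # hypothetical key, chosen for illustration
class ConsoleMonitor(Monitor):
    """Minimal monitor that prints metrics instead of persisting them."""

    def log_table(self, table_name, experiences_table, step: int):
        # The base Monitor.__init__ already stores project/name/role/config.
        print(f"[{self.role}] step={step} table={table_name} rows={len(experiences_table)}")

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        print(f"[{self.role}] step={step} {data}")

    def close(self) -> None:
        pass


# Callers resolve the class the same way explorer.py and verl_trainer.py now do:
#   monitor_cls = MONITOR.get(config.monitor.monitor_type)
#   monitor = monitor_cls(project=config.project, name=config.name, role="explorer", config=config)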