Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import explore
from trinity.common.constants import MonitorType


class BaseExplorerCase(RayUnittestBase):
Expand All @@ -23,7 +22,7 @@ def setUp(self):
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.algorithm.repeat_times = 2
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.monitor_type = "tensorboard"
self.config.project = "Trinity-unittest"
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import bench, both
from trinity.common.constants import MonitorType, SyncMethod
from trinity.common.constants import SyncMethod


class BaseTrainerCase(RayUnittestBase):
Expand All @@ -30,7 +30,7 @@ def setUp(self):
self.config.explorer.rollout_model.use_v1 = False
self.config.project = "Trainer-unittest"
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.monitor_type = "tensorboard"
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
self.config.synchronizer.sync_method = SyncMethod.NCCL
Expand Down
5 changes: 3 additions & 2 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from trinity.common.constants import (
AlgorithmType,
MonitorType,
PromptType,
ReadStrategy,
StorageType,
Expand Down Expand Up @@ -278,7 +277,9 @@ class TrainerConfig:
@dataclass
class MonitorConfig:
# TODO: support multiple monitors (List[str] of registered monitor names)
monitor_type: MonitorType = MonitorType.WANDB
monitor_type: str = "tensorboard"
# the default args for monitor
monitor_args: Dict = field(default_factory=dict)
# ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor
cache_dir: str = ""

Expand Down
4 changes: 2 additions & 2 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from trinity.explorer.runner_pool import RunnerPool
from trinity.manager.manager import CacheManager
from trinity.utils.log import get_logger
from trinity.utils.monitor import Monitor
from trinity.utils.monitor import MONITOR


@ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1})
Expand All @@ -47,7 +47,7 @@ def __init__(self, config: Config):
for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer))
self.runner_pool = self._init_runner_pool()
self.monitor = Monitor(
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
name=self.config.name,
role="explorer",
Expand Down
4 changes: 2 additions & 2 deletions trinity/trainer/verl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
pprint,
reduce_metrics,
)
from trinity.utils.monitor import Monitor
from trinity.utils.monitor import MONITOR


class _InternalDataLoader:
Expand Down Expand Up @@ -128,7 +128,7 @@ def __init__(
self.algorithm_type = (
AlgorithmType.PPO
) # TODO: initialize algorithm_type according to config
self.logger = Monitor(
self.logger = MONITOR.get(global_config.monitor.monitor_type)(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
role="trainer",
Expand Down
51 changes: 27 additions & 24 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Monitor"""

import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import numpy as np
Expand All @@ -8,11 +10,13 @@
from torch.utils.tensorboard import SummaryWriter

from trinity.common.config import Config
from trinity.common.constants import MonitorType
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

MONITOR = Registry("monitor")


class Monitor:
class Monitor(ABC):
"""Monitor"""

def __init__(
Expand All @@ -22,15 +26,25 @@ def __init__(
role: str,
config: Config = None, # pass the global Config for recording
) -> None:
if config.monitor.monitor_type == MonitorType.WANDB:
self.logger = WandbLogger(project, name, role, config)
elif config.monitor.monitor_type == MonitorType.TENSORBOARD:
self.logger = TensorboardLogger(project, name, role, config)
else:
raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}")
self.project = project
self.name = name
self.role = role
self.config = config

@abstractmethod
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
self.logger.log_table(table_name, experiences_table, step=step)
"""Log a table"""

@abstractmethod
def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""

@abstractmethod
def close(self) -> None:
"""Close the monitor"""

def __del__(self) -> None:
self.close()

def calculate_metrics(
self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
Expand All @@ -51,15 +65,9 @@ def calculate_metrics(
metrics[key] = val
return metrics

def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
self.logger.log(data, step=step, commit=commit)

def close(self) -> None:
self.logger.close()


class TensorboardLogger:
@MONITOR.register_module("tensorboard")
class TensorboardMonitor(Monitor):
def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard")
os.makedirs(self.tensorboard_dir, exist_ok=True)
Expand All @@ -77,11 +85,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
def close(self) -> None:
self.logger.close()

def __del__(self) -> None:
self.logger.close()


class WandbLogger:
@MONITOR.register_module("wandb")
class WandbMonitor(Monitor):
def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
self.logger = wandb.init(
project=project,
Expand All @@ -104,6 +110,3 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:

def close(self) -> None:
self.logger.finish()

def __del__(self) -> None:
self.logger.finish()