diff --git a/trinity/common/config.py b/trinity/common/config.py
index 1434e7a833..1e0b0c2509 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -6,7 +6,7 @@
 from omegaconf import OmegaConf
 
-from trinity.common.constants import AlgorithmType, PromptType, StorageType
+from trinity.common.constants import AlgorithmType, MonitorType, PromptType, StorageType
 from trinity.utils.log import get_logger
 
 logger = get_logger(__name__)
 
@@ -201,6 +201,7 @@ class MonitorConfig:
     # TODO: add more
     project: str = "trinity"
     name: str = "rft"
+    monitor_type: MonitorType = MonitorType.WANDB
 
     # ! DO NOT SET
     # the root directory for cache and meta files, automatically generated
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index 154eb360d2..bb0d967e6b 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -87,3 +87,10 @@ def is_rollout(self) -> bool:
     def is_dpo(self) -> bool:
         """Check if the algorithm is DPO."""
         return self == AlgorithmType.DPO
+
+
+class MonitorType(CaseInsensitiveEnum):
+    """Monitor Type."""
+
+    WANDB = "wandb"
+    TENSORBOARD = "tensorboard"
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index b57f484950..f35c55af0a 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -12,7 +12,7 @@
 import ray
 
 from trinity.buffer import get_buffer_reader
-from trinity.common.config import Config, TrainerConfig
+from trinity.common.config import Config
 from trinity.common.constants import AlgorithmType
 from trinity.common.experience import Experiences
 from trinity.utils.log import get_logger
@@ -37,7 +37,7 @@ def __init__(self, config: Config) -> None:
             if self.config.trainer.sft_warmup_iteration > 0
             else None
         )
-        self.engine = get_trainer_wrapper(config.trainer)
+        self.engine = get_trainer_wrapper(config)
 
     def prepare(self) -> None:
         """Prepare the trainer."""
@@ -146,9 +146,9 @@ def shutdown(self) -> None:
         """Shutdown the engine."""
 
 
-def get_trainer_wrapper(config: TrainerConfig) -> TrainEngineWrapper:
+def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
     """Get a trainer wrapper."""
-    if config.trainer_type == "verl":
+    if config.trainer.trainer_type == "verl":
         from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper
 
         return VerlPPOTrainerWrapper(config)
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 03c60cf2bb..89f649d3b6 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -13,7 +13,7 @@
 from verl.utils import hf_tokenizer
 from verl.utils.fs import copy_local_path_from_hdfs
 
-from trinity.common.config import TrainerConfig
+from trinity.common.config import Config
 from trinity.common.constants import AlgorithmType
 from trinity.common.experience import Experiences
 from trinity.trainer.trainer import TrainEngineWrapper
@@ -71,8 +71,9 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
 
     def __init__(
         self,
-        train_config: TrainerConfig,
+        global_config: Config,
     ):
+        train_config = global_config.trainer
         pprint(train_config.trainer_config)
         config = OmegaConf.structured(train_config.trainer_config)
         # download the checkpoint from hdfs
@@ -134,7 +135,7 @@ def __init__(
             project=config.trainer.project_name,
             name=config.trainer.experiment_name,
             role="trainer",
-            config=train_config,
+            config=global_config,
         )
         self.reset_experiences_example_table()
 
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 2ff2a61a70..23b96a3c11 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -1,11 +1,13 @@
 """Monitor"""
-
+import os
 from typing import Any, List, Optional, Union
 
 import numpy as np
 import pandas as pd
 import wandb
+from torch.utils.tensorboard import SummaryWriter
 
+from trinity.common.constants import MonitorType
 from trinity.utils.log import get_logger
 
 
@@ -19,19 +21,15 @@ def __init__(
         role: str,
         config: Any = None,
     ) -> None:
-        self.logger = wandb.init(
-            project=project,
-            group=name,
-            name=f"{name}_{role}",
-            tags=[role],
-            config=config,
-            save_code=False,
-        )
-        self.console_logger = get_logger(__name__)
+        if config.monitor.monitor_type == MonitorType.WANDB:
+            self.logger = WandbLogger(project, name, role, config)
+        elif config.monitor.monitor_type == MonitorType.TENSORBOARD:
+            self.logger = TensorboardLogger(project, name, role, config)
+        else:
+            raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}")
 
     def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
-        experiences_table = wandb.Table(dataframe=experiences_table)
-        self.log(data={table_name: experiences_table}, step=step)
+        self.logger.log_table(table_name, experiences_table, step=step)
 
     def calculate_metrics(
         self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
@@ -55,6 +53,46 @@ def calculate_metrics(
     def log(self, data: dict, step: int, commit: bool = False) -> None:
         """Log metrics."""
         self.logger.log(data, step=step, commit=commit)
+
+
+class TensorboardLogger:
+    def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
+        self.tensorboard_dir = os.path.join(config.monitor.job_dir, "tensorboard")
+        os.makedirs(self.tensorboard_dir, exist_ok=True)
+        self.logger = SummaryWriter(self.tensorboard_dir)
+        self.console_logger = get_logger(__name__)
+
+    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
+        pass
+
+    def log(self, data: dict, step: int, commit: bool = False) -> None:
+        """Log metrics."""
+        for key in data:
+            self.logger.add_scalar(key, data[key], step)
+
+    def __del__(self) -> None:
+        self.logger.close()
+
+
+class WandbLogger:
+    def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
+        self.logger = wandb.init(
+            project=project,
+            group=name,
+            name=f"{name}_{role}",
+            tags=[role],
+            config=config,
+            save_code=False,
+        )
+        self.console_logger = get_logger(__name__)
+
+    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
+        experiences_table = wandb.Table(dataframe=experiences_table)
+        self.log(data={table_name: experiences_table}, step=step)
+
+    def log(self, data: dict, step: int, commit: bool = False) -> None:
+        """Log metrics."""
+        self.logger.log(data, step=step, commit=commit)
         self.console_logger.info(f"Step {step}: {data}")
 
     def __del__(self) -> None:
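
Usage note (not part of the patch): a minimal sketch of how the new `monitor_type` switch dispatches `Monitor` to a backend. The `SimpleNamespace` stand-in, the `job_dir` value, and the metric name are illustrative assumptions; a real run passes the fully populated `trinity.common.config.Config`.

```python
# Hypothetical sketch: exercising Monitor with the new MonitorType switch.
# A SimpleNamespace stands in for trinity.common.config.Config here only so
# the snippet is self-contained; real code passes the project's Config object.
from types import SimpleNamespace

from trinity.common.constants import MonitorType
from trinity.utils.monitor import Monitor

config = SimpleNamespace(
    monitor=SimpleNamespace(
        # Given CaseInsensitiveEnum, a config value of "tensorboard" should
        # resolve to MonitorType.TENSORBOARD as well.
        monitor_type=MonitorType.TENSORBOARD,
        job_dir="/tmp/trinity_demo",  # TensorboardLogger writes to <job_dir>/tensorboard
    )
)

monitor = Monitor(project="trinity", name="rft", role="trainer", config=config)
# Dispatched to TensorboardLogger.log, i.e. one SummaryWriter.add_scalar per key.
monitor.log({"train/reward_mean": 0.42}, step=1)
```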