Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from omegaconf import OmegaConf

from trinity.common.constants import AlgorithmType, PromptType, StorageType
from trinity.common.constants import AlgorithmType, MonitorType, PromptType, StorageType
from trinity.utils.log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -201,6 +201,7 @@ class MonitorConfig:
# TODO: add more
project: str = "trinity"
name: str = "rft"
monitor_type: MonitorType = MonitorType.WANDB

# ! DO NOT SET
# the root directory for cache and meta files, automatically generated
Expand Down
7 changes: 7 additions & 0 deletions trinity/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,10 @@ def is_rollout(self) -> bool:
def is_dpo(self) -> bool:
"""Check if the algorithm is DPO."""
return self == AlgorithmType.DPO


class MonitorType(CaseInsensitiveEnum):
"""Monitor Type."""

WANDB = "wandb"
TENSORBOARD = "tensorboard"
8 changes: 4 additions & 4 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import ray

from trinity.buffer import get_buffer_reader
from trinity.common.config import Config, TrainerConfig
from trinity.common.config import Config
from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experiences
from trinity.utils.log import get_logger
Expand All @@ -37,7 +37,7 @@ def __init__(self, config: Config) -> None:
if self.config.trainer.sft_warmup_iteration > 0
else None
)
self.engine = get_trainer_wrapper(config.trainer)
self.engine = get_trainer_wrapper(config)

def prepare(self) -> None:
"""Prepare the trainer."""
Expand Down Expand Up @@ -146,9 +146,9 @@ def shutdown(self) -> None:
"""Shutdown the engine."""


def get_trainer_wrapper(config: TrainerConfig) -> TrainEngineWrapper:
def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
"""Get a trainer wrapper."""
if config.trainer_type == "verl":
if config.trainer.trainer_type == "verl":
from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

return VerlPPOTrainerWrapper(config)
Expand Down
7 changes: 4 additions & 3 deletions trinity/trainer/verl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs

from trinity.common.config import TrainerConfig
from trinity.common.config import Config
from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experiences
from trinity.trainer.trainer import TrainEngineWrapper
Expand Down Expand Up @@ -71,8 +71,9 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):

def __init__(
self,
train_config: TrainerConfig,
global_config: Config,
):
train_config = global_config.trainer
pprint(train_config.trainer_config)
config = OmegaConf.structured(train_config.trainer_config)
# download the checkpoint from hdfs
Expand Down Expand Up @@ -134,7 +135,7 @@ def __init__(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
role="trainer",
config=train_config,
config=global_config,
)
self.reset_experiences_example_table()

Expand Down
62 changes: 50 additions & 12 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Monitor"""

import os
from typing import Any, List, Optional, Union

import numpy as np
import pandas as pd
import wandb
from torch.utils.tensorboard import SummaryWriter

from trinity.common.constants import MonitorType
from trinity.utils.log import get_logger


Expand All @@ -19,19 +21,15 @@ def __init__(
role: str,
config: Any = None,
) -> None:
self.logger = wandb.init(
project=project,
group=name,
name=f"{name}_{role}",
tags=[role],
config=config,
save_code=False,
)
self.console_logger = get_logger(__name__)
if config.monitor.monitor_type == MonitorType.WANDB:
self.logger = WandbLogger(project, name, role, config)
elif config.monitor.monitor_type == MonitorType.TENSORBOARD:
self.logger = TensorboardLogger(project, name, role, config)
else:
raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}")

def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
experiences_table = wandb.Table(dataframe=experiences_table)
self.log(data={table_name: experiences_table}, step=step)
self.logger.log_table(table_name, experiences_table, step=step)

def calculate_metrics(
self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
Expand All @@ -55,6 +53,46 @@ def calculate_metrics(
def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
self.logger.log(data, step=step, commit=commit)


class TensorboardLogger:
def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
self.tensorboard_dir = os.path.join(config.monitor.job_dir, "tensorboard")
os.makedirs(self.tensorboard_dir, exist_ok=True)
self.logger = SummaryWriter(self.tensorboard_dir)
self.console_logger = get_logger(__name__)

def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
pass

def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
for key in data:
self.logger.add_scalar(key, data[key], step)

def __del__(self) -> None:
self.logger.close()


class WandbLogger:
def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
self.logger = wandb.init(
project=project,
group=name,
name=f"{name}_{role}",
tags=[role],
config=config,
save_code=False,
)
self.console_logger = get_logger(__name__)

def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
experiences_table = wandb.Table(dataframe=experiences_table)
self.log(data={table_name: experiences_table}, step=step)

def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
self.logger.log(data, step=step, commit=commit)
self.console_logger.info(f"Step {step}: {data}")

def __del__(self) -> None:
Expand Down