Skip to content

Commit ffe22ac

Browse files
authored
Add tensorboard monitor (#21)
1 parent f0a6024 commit ffe22ac

File tree

5 files changed

+67
-20
lines changed

5 files changed

+67
-20
lines changed

trinity/common/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from omegaconf import OmegaConf
88

9-
from trinity.common.constants import AlgorithmType, PromptType, StorageType
9+
from trinity.common.constants import AlgorithmType, MonitorType, PromptType, StorageType
1010
from trinity.utils.log import get_logger
1111

1212
logger = get_logger(__name__)
@@ -201,6 +201,7 @@ class MonitorConfig:
201201
# TODO: add more
202202
project: str = "trinity"
203203
name: str = "rft"
204+
monitor_type: MonitorType = MonitorType.WANDB
204205

205206
# ! DO NOT SET
206207
# the root directory for cache and meta files, automatically generated

trinity/common/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,10 @@ def is_rollout(self) -> bool:
8787
def is_dpo(self) -> bool:
8888
"""Check if the algorithm is DPO."""
8989
return self == AlgorithmType.DPO
90+
91+
92+
class MonitorType(CaseInsensitiveEnum):
    """Supported experiment-tracking backends for the monitor.

    The member values are the strings users put in config files; lookup is
    presumably case-insensitive via ``CaseInsensitiveEnum`` (defined above).
    """

    WANDB = "wandb"
    TENSORBOARD = "tensorboard"

trinity/trainer/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import ray
1313

1414
from trinity.buffer import get_buffer_reader
15-
from trinity.common.config import Config, TrainerConfig
15+
from trinity.common.config import Config
1616
from trinity.common.constants import AlgorithmType
1717
from trinity.common.experience import Experiences
1818
from trinity.utils.log import get_logger
@@ -37,7 +37,7 @@ def __init__(self, config: Config) -> None:
3737
if self.config.trainer.sft_warmup_iteration > 0
3838
else None
3939
)
40-
self.engine = get_trainer_wrapper(config.trainer)
40+
self.engine = get_trainer_wrapper(config)
4141

4242
def prepare(self) -> None:
4343
"""Prepare the trainer."""
@@ -146,9 +146,9 @@ def shutdown(self) -> None:
146146
"""Shutdown the engine."""
147147

148148

149-
def get_trainer_wrapper(config: TrainerConfig) -> TrainEngineWrapper:
149+
def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
    """Build the training-engine wrapper selected by the config.

    Args:
        config: The global ``Config``; ``config.trainer.trainer_type``
            selects the backend implementation.

    Returns:
        A ``TrainEngineWrapper`` for the configured backend.

    Raises:
        ValueError: If ``config.trainer.trainer_type`` names an unsupported
            backend (previously this fell through and returned ``None``,
            deferring the failure to the first attribute access on the
            engine).
    """
    if config.trainer.trainer_type == "verl":
        # Imported lazily so the verl dependency is only required when the
        # verl backend is actually selected.
        from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

        return VerlPPOTrainerWrapper(config)
    raise ValueError(f"Unsupported trainer_type: {config.trainer.trainer_type}")

trinity/trainer/verl_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from verl.utils import hf_tokenizer
1414
from verl.utils.fs import copy_local_path_from_hdfs
1515

16-
from trinity.common.config import TrainerConfig
16+
from trinity.common.config import Config
1717
from trinity.common.constants import AlgorithmType
1818
from trinity.common.experience import Experiences
1919
from trinity.trainer.trainer import TrainEngineWrapper
@@ -71,8 +71,9 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
7171

7272
def __init__(
7373
self,
74-
train_config: TrainerConfig,
74+
global_config: Config,
7575
):
76+
train_config = global_config.trainer
7677
pprint(train_config.trainer_config)
7778
config = OmegaConf.structured(train_config.trainer_config)
7879
# download the checkpoint from hdfs
@@ -134,7 +135,7 @@ def __init__(
134135
project=config.trainer.project_name,
135136
name=config.trainer.experiment_name,
136137
role="trainer",
137-
config=train_config,
138+
config=global_config,
138139
)
139140
self.reset_experiences_example_table()
140141

trinity/utils/monitor.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Monitor"""
2-
2+
import os
33
from typing import Any, List, Optional, Union
44

55
import numpy as np
66
import pandas as pd
77
import wandb
8+
from torch.utils.tensorboard import SummaryWriter
89

10+
from trinity.common.constants import MonitorType
911
from trinity.utils.log import get_logger
1012

1113

@@ -19,19 +21,15 @@ def __init__(
1921
role: str,
2022
config: Any = None,
2123
) -> None:
22-
self.logger = wandb.init(
23-
project=project,
24-
group=name,
25-
name=f"{name}_{role}",
26-
tags=[role],
27-
config=config,
28-
save_code=False,
29-
)
30-
self.console_logger = get_logger(__name__)
24+
if config.monitor.monitor_type == MonitorType.WANDB:
25+
self.logger = WandbLogger(project, name, role, config)
26+
elif config.monitor.monitor_type == MonitorType.TENSORBOARD:
27+
self.logger = TensorboardLogger(project, name, role, config)
28+
else:
29+
raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}")
3130

3231
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
33-
experiences_table = wandb.Table(dataframe=experiences_table)
34-
self.log(data={table_name: experiences_table}, step=step)
32+
self.logger.log_table(table_name, experiences_table, step=step)
3533

3634
def calculate_metrics(
3735
self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
@@ -55,6 +53,46 @@ def calculate_metrics(
5553
def log(self, data: dict, step: int, commit: bool = False) -> None:
5654
"""Log metrics."""
5755
self.logger.log(data, step=step, commit=commit)
56+
57+
58+
class TensorboardLogger:
    """Monitor backend that writes scalar metrics to TensorBoard event files."""

    def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
        # Event files live under <job_dir>/tensorboard so each job keeps its
        # own log directory. `project`/`name`/`role` are accepted for
        # interface parity with WandbLogger but are not encoded in the path.
        self.tensorboard_dir = os.path.join(config.monitor.job_dir, "tensorboard")
        os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.logger = SummaryWriter(self.tensorboard_dir)
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Tables are not supported by the TensorBoard backend; intentional no-op."""
        pass

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log each entry of *data* as a scalar at *step*.

        ``commit`` exists for wandb interface parity and is ignored here.
        """
        for key, value in data.items():
            self.logger.add_scalar(key, value, step)

    def __del__(self) -> None:
        # Guard: __del__ can run even if __init__ raised before `logger`
        # was bound (e.g. SummaryWriter construction failed), which would
        # otherwise surface a spurious AttributeError at interpreter exit.
        writer = getattr(self, "logger", None)
        if writer is not None:
            writer.close()
75+
76+
77+
class WandbLogger:
78+
def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
79+
self.logger = wandb.init(
80+
project=project,
81+
group=name,
82+
name=f"{name}_{role}",
83+
tags=[role],
84+
config=config,
85+
save_code=False,
86+
)
87+
self.console_logger = get_logger(__name__)
88+
89+
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
90+
experiences_table = wandb.Table(dataframe=experiences_table)
91+
self.log(data={table_name: experiences_table}, step=step)
92+
93+
def log(self, data: dict, step: int, commit: bool = False) -> None:
94+
"""Log metrics."""
95+
self.logger.log(data, step=step, commit=commit)
5896
self.console_logger.info(f"Step {step}: {data}")
5997

6098
def __del__(self) -> None:

0 commit comments

Comments
 (0)