Skip to content

Commit bdeb8ce

Browse files
committed
fix tensorboard monitor
1 parent d577f21 commit bdeb8ce

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

trinity/trainer/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import ray
1313

1414
from trinity.buffer import get_buffer_reader
15-
from trinity.common.config import Config, TrainerConfig
15+
from trinity.common.config import Config
1616
from trinity.common.constants import AlgorithmType
1717
from trinity.common.experience import Experiences
1818
from trinity.utils.log import get_logger
@@ -37,7 +37,7 @@ def __init__(self, config: Config) -> None:
3737
if self.config.trainer.sft_warmup_iteration > 0
3838
else None
3939
)
40-
self.engine = get_trainer_wrapper(config.trainer)
40+
self.engine = get_trainer_wrapper(config)
4141

4242
def prepare(self) -> None:
4343
"""Prepare the trainer."""
@@ -146,9 +146,9 @@ def shutdown(self) -> None:
146146
"""Shutdown the engine."""
147147

148148

149-
def get_trainer_wrapper(config: TrainerConfig) -> TrainEngineWrapper:
149+
def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
150150
"""Get a trainer wrapper."""
151-
if config.trainer_type == "verl":
151+
if config.trainer.trainer_type == "verl":
152152
from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper
153153

154154
return VerlPPOTrainerWrapper(config)

trinity/trainer/verl_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from verl.utils import hf_tokenizer
1414
from verl.utils.fs import copy_local_path_from_hdfs
1515

16-
from trinity.common.config import TrainerConfig
16+
from trinity.common.config import Config
1717
from trinity.common.constants import AlgorithmType
1818
from trinity.common.experience import Experiences
1919
from trinity.trainer.trainer import TrainEngineWrapper
@@ -71,8 +71,9 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
7171

7272
def __init__(
7373
self,
74-
train_config: TrainerConfig,
74+
global_config: Config,
7575
):
76+
train_config = global_config.trainer
7677
pprint(train_config.trainer_config)
7778
config = OmegaConf.structured(train_config.trainer_config)
7879
# download the checkpoint from hdfs
@@ -134,7 +135,7 @@ def __init__(
134135
project=config.trainer.project_name,
135136
name=config.trainer.experiment_name,
136137
role="trainer",
137-
config=train_config,
138+
config=global_config,
138139
)
139140
self.reset_experiences_example_table()
140141

0 commit comments

Comments
 (0)