File tree Expand file tree Collapse file tree 2 files changed +8
-7
lines changed
Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change 1212import ray
1313
1414from trinity .buffer import get_buffer_reader
15- from trinity .common .config import Config , TrainerConfig
15+ from trinity .common .config import Config
1616from trinity .common .constants import AlgorithmType
1717from trinity .common .experience import Experiences
1818from trinity .utils .log import get_logger
@@ -37,7 +37,7 @@ def __init__(self, config: Config) -> None:
3737 if self .config .trainer .sft_warmup_iteration > 0
3838 else None
3939 )
40- self .engine = get_trainer_wrapper (config . trainer )
40+ self .engine = get_trainer_wrapper (config )
4141
4242 def prepare (self ) -> None :
4343 """Prepare the trainer."""
@@ -146,9 +146,9 @@ def shutdown(self) -> None:
146146 """Shutdown the engine."""
147147
148148
149- def get_trainer_wrapper (config : TrainerConfig ) -> TrainEngineWrapper :
149+ def get_trainer_wrapper (config : Config ) -> TrainEngineWrapper :
150150 """Get a trainer wrapper."""
151- if config .trainer_type == "verl" :
151+ if config .trainer . trainer_type == "verl" :
152152 from trinity .trainer .verl_trainer import VerlPPOTrainerWrapper
153153
154154 return VerlPPOTrainerWrapper (config )
Original file line number Diff line number Diff line change 1313from verl .utils import hf_tokenizer
1414from verl .utils .fs import copy_local_path_from_hdfs
1515
16- from trinity .common .config import TrainerConfig
16+ from trinity .common .config import Config
1717from trinity .common .constants import AlgorithmType
1818from trinity .common .experience import Experiences
1919from trinity .trainer .trainer import TrainEngineWrapper
@@ -71,8 +71,9 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
7171
7272 def __init__ (
7373 self ,
74- train_config : TrainerConfig ,
74+ global_config : Config ,
7575 ):
76+ train_config = global_config .trainer
7677 pprint (train_config .trainer_config )
7778 config = OmegaConf .structured (train_config .trainer_config )
7879 # download the checkpoint from hdfs
@@ -134,7 +135,7 @@ def __init__(
134135 project = config .trainer .project_name ,
135136 name = config .trainer .experiment_name ,
136137 role = "trainer" ,
137- config = train_config ,
138+ config = global_config ,
138139 )
139140 self .reset_experiences_example_table ()
140141
You can’t perform that action at this time.
0 commit comments