@@ -173,8 +173,7 @@ class ExplorerConfig:
173173@dataclass
174174class TrainerConfig :
175175 trainer_type : str = "verl"
176- trainer_data_type : str = "RFT"
177- trainer_config_path : str = "examples/ppo_countdown/train_countdown.yaml"
176+ trainer_config_path : str = ""
178177 eval_interval : int = 100
179178 enable_preview : bool = True # enable rollout preview in wandb
180179 trainer_config : Any = None
@@ -185,16 +184,6 @@ class TrainerConfig:
185184 # warmup config
186185 sft_warmup_iteration : int = 0
187186
188- def __post_init__ (self ):
189- if self .trainer_type == "verl" :
190- from trinity .common .verl_config import load_config
191-
192- if not os .path .isfile (self .trainer_config_path ):
193- raise ValueError (f"Invalid trainer config path: { self .trainer_config_path } " )
194- self .trainer_config = load_config (self .trainer_config_path )
195- else :
196- raise ValueError (f"Invalid trainer type: { self .trainer_type } " )
197-
198187
199188@dataclass
200189class MonitorConfig :
@@ -285,6 +274,15 @@ def _check_buffer(self) -> None:
285274
286275 def check_and_update (self ) -> None :
287276 """Check and update the config."""
277+ if self .trainer .trainer_type == "verl" :
278+ from trinity .common .verl_config import load_config
279+
280+ if not os .path .isfile (self .trainer .trainer_config_path ):
281+ raise ValueError (f"Invalid trainer config path: { self .trainer .trainer_config_path } " )
282+ self .trainer .trainer_config = load_config (self .trainer .trainer_config_path )
283+ else :
284+ raise ValueError (f"Invalid trainer type: { self .trainer_type } " )
285+
288286 # check mode
289287 if self .mode not in ["explore" , "train" , "both" ]:
290288 raise ValueError (f"Invalid mode: { self .mode } " )
0 commit comments