@@ -100,7 +100,7 @@ class DatasetConfig:
100100
101101 name : str
102102 storage_type : StorageType
103- algorithm_type : AlgorithmType
103+ algorithm_type : AlgorithmType = AlgorithmType . PPO
104104 path : Optional [str ] = None
105105 kwargs : Dict [str , Any ] = field (default_factory = dict )
106106
@@ -176,7 +176,7 @@ class TrainerConfig:
176176 trainer_config_path : str = ""
177177 eval_interval : int = 100
178178 enable_preview : bool = True # enable rollout preview in wandb
179- trainer_config : Any = None
179+ trainer_config : Any = field ( default_factory = dict )
180180
181181 # train algorithm
182182 algorithm_type : AlgorithmType = AlgorithmType .PPO
@@ -273,11 +273,21 @@ def _check_buffer(self) -> None:
def check_and_update(self) -> None:
    """Check and update the config.

    For the ``"verl"`` trainer type, resolve ``self.trainer.trainer_config``:

    * If ``self.trainer.trainer_config`` is already set (truthy), merge it
      onto the ``veRLConfig`` structured schema with OmegaConf and store the
      resulting object back on ``self.trainer``.
    * Otherwise load the config from ``self.trainer.trainer_config_path``.

    Raises:
        ValueError: If the trainer type is not ``"verl"``, or if no inline
            trainer config is set and ``trainer_config_path`` does not point
            to an existing file.
    """
    trainer_section = self.trainer

    if trainer_section.trainer_type != "verl":
        # The trainer type lives on ``self.trainer`` (that is what the check
        # above reads); formatting ``self.trainer_type`` here would fail with
        # AttributeError before the intended ValueError could be raised.
        raise ValueError(f"Invalid trainer type: {trainer_section.trainer_type}")

    if trainer_section.trainer_config:
        # Inline config takes precedence over the config file path.
        from trinity.common.verl_config import veRLConfig

        schema = OmegaConf.structured(veRLConfig)
        merged = OmegaConf.merge(schema, trainer_section.trainer_config)
        trainer_section.trainer_config = OmegaConf.to_object(merged)
    elif os.path.isfile(trainer_section.trainer_config_path):
        from trinity.common.verl_config import load_config

        trainer_section.trainer_config = load_config(trainer_section.trainer_config_path)
    else:
        raise ValueError(
            f"Invalid trainer config path: {trainer_section.trainer_config_path}"
        )
283293
0 commit comments