Skip to content

Commit 1b03d27

Browse files
committed
Add some help messages and Allow users to not set traine_config_path
1 parent 1095bdd commit 1b03d27

File tree

2 files changed

+321
-303
lines changed

2 files changed

+321
-303
lines changed

trinity/common/config.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class DatasetConfig:
100100

101101
name: str
102102
storage_type: StorageType
103-
algorithm_type: AlgorithmType
103+
algorithm_type: AlgorithmType = AlgorithmType.PPO
104104
path: Optional[str] = None
105105
kwargs: Dict[str, Any] = field(default_factory=dict)
106106

@@ -176,7 +176,7 @@ class TrainerConfig:
176176
trainer_config_path: str = ""
177177
eval_interval: int = 100
178178
enable_preview: bool = True # enable rollout preview in wandb
179-
trainer_config: Any = None
179+
trainer_config: Any = field(default_factory=dict)
180180

181181
# train algorithm
182182
algorithm_type: AlgorithmType = AlgorithmType.PPO
@@ -273,11 +273,21 @@ def _check_buffer(self) -> None:
273273
def check_and_update(self) -> None:
274274
"""Check and update the config."""
275275
if self.trainer.trainer_type == "verl":
276-
from trinity.common.verl_config import load_config
277-
278-
if not os.path.isfile(self.trainer.trainer_config_path):
279-
raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
280-
self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
276+
if self.trainer.trainer_config:
277+
from trinity.common.verl_config import veRLConfig
278+
279+
trainer_config_schema = OmegaConf.structured(veRLConfig)
280+
trainer_config = OmegaConf.merge(trainer_config_schema, self.trainer.trainer_config)
281+
self.trainer.trainer_config = OmegaConf.to_object(trainer_config)
282+
else:
283+
if os.path.isfile(self.trainer.trainer_config_path):
284+
from trinity.common.verl_config import load_config
285+
286+
self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
287+
else:
288+
raise ValueError(
289+
f"Invalid trainer config path: {self.trainer.trainer_config_path}"
290+
)
281291
else:
282292
raise ValueError(f"Invalid trainer type: {self.trainer_type}")
283293

0 commit comments

Comments
 (0)