Merged

Changes from 10 commits
12 changes: 12 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -3,6 +3,18 @@
 The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.
 
 
+## Monitor
+
+```yaml
+monitor:
+  project: "Trinity-RFT-countdown"
+  name: "qwen2.5-1.5B-countdown"
+```
+
+- `monitor.project`: The project name. It must be set manually.
+- `monitor.name`: The name of the experiment. It must be set manually.
+
+
 ## Monitor
 
 ```yaml
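For context, the added `monitor` block is plain YAML. A minimal sketch (not part of this PR) of reading it with PyYAML; the local file name is an assumption:

```python
# Minimal sketch, assuming a local countdown.yaml shaped like the docs above.
import yaml  # PyYAML

with open("countdown.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["monitor"]["project"])  # "Trinity-RFT-countdown"
print(cfg["monitor"]["name"])     # "qwen2.5-1.5B-countdown"
```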
1 change: 0 additions & 1 deletion examples/dpo_humanlike/dpo.yaml
@@ -22,7 +22,6 @@ buffer:
   train_dataset:
     name: dpo_buffer
     storage_type: file
-    algorithm_type: dpo
     path: '/PATH/TO/DATASET/'
     kwargs:
       prompt_type: plaintext # plaintext/messages
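The `algorithm_type` key is dropped from the example datasets because, after this PR, `_check_buffer` overwrites it from `trainer.algorithm_type` (see the config.py hunk below). A self-contained sketch of that behavior, using plain dicts as stand-ins for the real config classes:

```python
# Stand-ins for DatasetConfig and TrainerConfig; not the real classes.
train_dataset = {"name": "dpo_buffer", "storage_type": "file"}
trainer = {"algorithm_type": "dpo"}

# Old behavior: a mismatching train_dataset["algorithm_type"] raised ValueError.
# New behavior: the trainer's value wins unconditionally, so the YAML key is redundant.
train_dataset["algorithm_type"] = trainer["algorithm_type"]
assert train_dataset["algorithm_type"] == "dpo"
```

The same reasoning applies to the remaining example YAMLs below.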
1 change: 0 additions & 1 deletion examples/grpo_alfworld/alfworld.yaml
@@ -21,7 +21,6 @@ buffer:
   train_dataset:
     name: alfworld_buffer
     storage_type: queue
-    algorithm_type: ppo
     path: 'sqlite:///alfworld.db'
 explorer:
   engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/grpo_gsm8k/gsm8k.yaml
@@ -40,7 +40,6 @@ buffer:
   # sft_warmup_dataset: # Uncomment these to enable sft warmup
   #   name: warmup_data
   #   storage_type: file
-  #   algorithm_type: sft
   #   path: '/PATH/TO/WARMUP_DATA/'
   #   kwargs:
   #     prompt_type: plaintext
1 change: 0 additions & 1 deletion examples/grpo_webshop/webshop.yaml
@@ -21,7 +21,6 @@ buffer:
   train_dataset:
     name: webshop_buffer
     storage_type: queue
-    algorithm_type: ppo
     path: 'sqlite:///webshop.db'
 explorer:
   engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -20,7 +20,6 @@ buffer:
   train_dataset:
     name: gsm8k_buffer
     storage_type: queue
-    algorithm_type: opmd
     path: 'sqlite:///gsm8k_opmd.db'
 explorer:
   engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/ppo_countdown/countdown.yaml
@@ -23,7 +23,6 @@ buffer:
   train_dataset:
     name: countdown_buffer
     storage_type: queue
-    algorithm_type: ppo
     path: 'sqlite:////countdown.db'
 explorer:
   engine_type: vllm_async
1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
     "math_verify",
     "ninja",
     "fire",
+    "streamlit",
     "flask",
     "requests",
     "tensorboard",
34 changes: 21 additions & 13 deletions trinity/common/config.py
@@ -101,7 +101,7 @@ class DatasetConfig:
 
     name: str
     storage_type: StorageType
-    algorithm_type: AlgorithmType
+    algorithm_type: AlgorithmType = AlgorithmType.PPO
     path: Optional[str] = None
     kwargs: Dict[str, Any] = field(default_factory=dict)
 
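A standalone sketch of what the new default buys: `algorithm_type` can now be omitted when a `DatasetConfig` is built. The enum values and `storage_type` type are assumptions; only the field shape mirrors the diff.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional

class AlgorithmType(Enum):  # assumed values; the real enum lives in Trinity-RFT
    PPO = "ppo"
    SFT = "sft"

@dataclass
class DatasetConfig:
    name: str
    storage_type: str  # stand-in for the real StorageType enum
    algorithm_type: AlgorithmType = AlgorithmType.PPO  # the new default
    path: Optional[str] = None
    kwargs: Dict[str, Any] = field(default_factory=dict)

cfg = DatasetConfig(name="gsm8k_buffer", storage_type="queue")
print(cfg.algorithm_type)  # AlgorithmType.PPO
```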
@@ -143,7 +143,7 @@ class ExplorerConfig:
     # For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num`
     runner_num: int = 1
 
-    # repeat each task for `repeat_times` times (for GPRO-like algrorithms)
+    # repeat each task for `repeat_times` times (for GPRO-like algorithms)
     repeat_times: int = 1
 
     # for rollout tokneize
@@ -177,7 +177,7 @@ class TrainerConfig:
     trainer_config_path: str = ""
     eval_interval: int = 100
     enable_preview: bool = True # enable rollout preview in wandb
-    trainer_config: Any = None
+    trainer_config: Any = field(default_factory=dict)
 
     # train algorithm
     algorithm_type: AlgorithmType = AlgorithmType.PPO
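Why `trainer_config` moves from `None` to `field(default_factory=dict)`: each instance gets a fresh dict that stays falsy when empty, so the `if self.trainer.trainer_config:` branch added below works whether or not the user supplies an inline config. A small sketch under those assumptions:

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class TrainerConfigSketch:  # illustrative stand-in, not the real TrainerConfig
    trainer_config_path: str = ""
    trainer_config: Any = field(default_factory=dict)  # one fresh dict per instance

tc = TrainerConfigSketch()
if tc.trainer_config:  # an empty dict is falsy
    print("use inline trainer_config")
else:
    print("fall back to trainer_config_path")  # printed
```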
@@ -266,21 +266,29 @@ def _check_buffer(self) -> None:
         else:
             if self.buffer.train_dataset is None:
                 raise ValueError("buffer.train_dataset is required when mode is not 'both'")
-            if self.buffer.train_dataset.algorithm_type != self.trainer.algorithm_type:
-                raise ValueError(
-                    f"buffer.train_dataset.algorithm_type ({self.buffer.train_dataset.algorithm_type}) "
-                    f"is not consistent with trainer.algorithm_type ({self.trainer.algorithm_type})"
-                )
+            self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type
+        if self.buffer.sft_warmup_dataset is not None:
+            self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
         self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times
 
     def check_and_update(self) -> None:
         """Check and update the config."""
         if self.trainer.trainer_type == "verl":
-            from trinity.common.verl_config import load_config
-
-            if not os.path.isfile(self.trainer.trainer_config_path):
-                raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
-            self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+            if self.trainer.trainer_config:
+                from trinity.common.verl_config import veRLConfig
+
+                trainer_config_schema = OmegaConf.structured(veRLConfig)
+                trainer_config = OmegaConf.merge(trainer_config_schema, self.trainer.trainer_config)
+                self.trainer.trainer_config = OmegaConf.to_object(trainer_config)
+            else:
+                if os.path.isfile(self.trainer.trainer_config_path):
+                    from trinity.common.verl_config import load_config
+
+                    self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+                else:
+                    raise ValueError(
+                        f"Invalid trainer config path: {self.trainer.trainer_config_path}"
+                    )
         else:
             raise ValueError(f"Invalid trainer type: {self.trainer_type}")
 
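The new inline-config branch validates a user-supplied dict against the structured `veRLConfig` schema. A minimal standalone sketch of that OmegaConf pattern, with a made-up schema in place of `veRLConfig`:

```python
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class VerlSchema:  # hypothetical stand-in for veRLConfig
    lr: float = 1e-5
    total_steps: int = 1000

schema = OmegaConf.structured(VerlSchema)
merged = OmegaConf.merge(schema, {"lr": 3e-6})  # partial user dict; missing keys keep defaults
cfg = OmegaConf.to_object(merged)               # back to a plain VerlSchema instance
print(cfg.lr, cfg.total_steps)  # 3e-06 1000
```

Merging against `OmegaConf.structured(...)` also rejects keys that are not in the schema, which catches typos in the inline config early.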