docs/sphinx_doc/source/tutorial/trinity_configs.md (12 additions, 0 deletions)
@@ -3,6 +3,18 @@
The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.


## Monitor

```yaml
monitor:
project: "Trinity-RFT-countdown"
name: "qwen2.5-1.5B-countdown"
```

- `monitor.project`: The project name. It must be set manually.
- `monitor.name`: The name of the experiment. It must be set manually.


## Monitor

```yaml
examples/dpo_humanlike/dpo.yaml (0 additions, 1 deletion)
@@ -22,7 +22,6 @@ buffer:
train_dataset:
name: dpo_buffer
storage_type: file
algorithm_type: dpo
path: '/PATH/TO/DATASET/'
kwargs:
prompt_type: plaintext # plaintext/messages
examples/grpo_alfworld/alfworld.yaml (0 additions, 1 deletion)
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: alfworld_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:///alfworld.db'
explorer:
engine_type: vllm_async
examples/grpo_gsm8k/gsm8k.yaml (0 additions, 2 deletions)
@@ -34,12 +34,10 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:////gsm8k.db'
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
# algorithm_type: sft
# path: '/PATH/TO/WARMUP_DATA/'
# kwargs:
# prompt_type: plaintext
examples/grpo_webshop/webshop.yaml (0 additions, 1 deletion)
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: webshop_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:///webshop.db'
explorer:
engine_type: vllm_async
examples/opmd_gsm8k/opmd_gsm8k.yaml (0 additions, 1 deletion)
@@ -20,7 +20,6 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
algorithm_type: opmd
path: 'sqlite:///gsm8k_opmd.db'
explorer:
engine_type: vllm_async
examples/ppo_countdown/countdown.yaml (0 additions, 1 deletion)
@@ -23,7 +23,6 @@ buffer:
train_dataset:
name: countdown_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:////countdown.db'
explorer:
engine_type: vllm_async
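Across the example configs above, the only change is the removal of `algorithm_type` from `buffer.train_dataset` (and from the commented-out `sft_warmup_dataset` in `gsm8k.yaml`): as the `trinity/common/config.py` change further down shows, this field is now filled in automatically from the trainer configuration. A minimal sketch of a resulting block, using the `countdown.yaml` values shown in this diff:

```yaml
buffer:
  train_dataset:
    name: countdown_buffer
    storage_type: queue
    # no algorithm_type here anymore; it is copied from the trainer config at check time
    path: 'sqlite:////countdown.db'
```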
pyproject.toml (1 addition, 0 deletions)
@@ -32,6 +32,7 @@ dependencies = [
"math_verify",
"ninja",
"fire",
"streamlit",
"flask",
"requests",
"tensorboard",
trinity/common/config.py (4 additions, 6 deletions)
@@ -142,7 +142,7 @@ class ExplorerConfig:
# For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num`
runner_num: int = 1

# repeat each task for `repeat_times` times (for GPRO-like algrorithms)
# repeat each task for `repeat_times` times (for GPRO-like algorithms)
repeat_times: int = 1

# for rollout tokneize
@@ -265,11 +265,9 @@ def _check_buffer(self) -> None:
else:
if self.buffer.train_dataset is None:
raise ValueError("buffer.train_dataset is required when mode is not 'both'")
if self.buffer.train_dataset.algorithm_type != self.trainer.algorithm_type:
raise ValueError(
f"buffer.train_dataset.algorithm_type ({self.buffer.train_dataset.algorithm_type}) "
f"is not consistent with trainer.algorithm_type ({self.trainer.algorithm_type})"
)
self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type
if self.buffer.sft_warmup_dataset is not None:
self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times

def check_and_update(self) -> None:
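The `_check_buffer` change replaces the old consistency check with an assignment: rather than raising when `buffer.train_dataset.algorithm_type` disagrees with `trainer.algorithm_type`, the trainer's value is now copied onto the dataset, and an `sft_warmup_dataset`, if configured, is always set to `AlgorithmType.SFT`. A sketch of what this means at the YAML level, using the `gsm8k.yaml` values from this diff and assuming the `trainer` section keys mirror the config attributes (that section is not shown here):

```yaml
trainer:
  algorithm_type: ppo        # assumed key layout; the single place the algorithm is declared
buffer:
  train_dataset:
    name: gsm8k_buffer       # inherits ppo from trainer during check_and_update
    storage_type: queue
    path: 'sqlite:////gsm8k.db'
  sft_warmup_dataset:        # when enabled, always treated as sft regardless of trainer
    name: warmup_data
    storage_type: file
    path: '/PATH/TO/WARMUP_DATA/'
```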