
Commit b8289ad: fix some comments
Parent: 5d8b6d0

12 files changed (+37, -27 lines)

docs/sphinx_doc/source/tutorial/example_async_mode.md (2 additions, 2 deletions)

````diff
@@ -7,7 +7,7 @@ Trinity-RFT supports an asynchronous mode by running the trainer and explorer in
 For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
 The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
 In addition, we need to configure the following parameters in both files.
-The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.
+The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.
 
 ```yaml
 project: tutorial
@@ -24,7 +24,7 @@ buffer:
 
 synchronizer:
   sync_method: 'checkpoint'
-  sync_interval: <sync_iteration_interval>
+  sync_interval: <sync_interval>
 ```
 
 You may run this example with the following command:
````
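As a quick worked example of the synchronization cadence described in this doc change (the numbers below are hypothetical, not taken from the repo's configs):

```python
# Weights are synchronized once every sync_interval * batch_size tasks.
# Hypothetical values for illustration only.
sync_interval = 10  # synchronizer.sync_interval
batch_size = 96     # buffer batch size

tasks_between_syncs = sync_interval * batch_size
print(tasks_between_syncs)  # 960
```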

docs/sphinx_doc/source/tutorial/trinity_configs.md (5 additions, 5 deletions)

````diff
@@ -5,16 +5,16 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as
 ## Global Config
 
 ```yaml
-mode: both
 project: Trinity-RFT
 name: example
-checkpoint_root_dir: /PATH/TO/CHECKPOINT_DIR
+mode: both
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 ```
 
-- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
 - `project`: The name of the project.
 - `name`: The name of the experiment.
-- `checkpoint_root_dir`: The root directory of the checkpoint.
+- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
+- `checkpoint_root_dir`: The root directory to save the checkpoints. Specifically, the generated checkpoints will be saved in `<checkpoint_root_dir>/<project>/<name>/`.
 
 ## Algorithm
 
@@ -24,7 +24,7 @@ algorithm:
   repeat_times: 1
 ```
 
-- `algorithm.algorithm_type`: The type of the algorithm, Support `ppo`, `grpo`, `opmd` and `dpo`.
+- `algorithm.algorithm_type`: The type of the algorithm. Supports `ppo`, `grpo`, `opmd` and `dpo`.
 - `algorithm.repeat_times`: The number of times to repeat each task. Used for GRPO-like algorithm. Default is `1`.
 
 ## Monitor
````
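The checkpoint layout stated in the new `checkpoint_root_dir` description can be sketched as follows (a minimal illustration; the use of `os.path.join` is an assumption, only the resulting `<checkpoint_root_dir>/<project>/<name>/` layout comes from the doc):

```python
import os

def checkpoint_dir(checkpoint_root_dir: str, project: str, name: str) -> str:
    # Checkpoints land under <checkpoint_root_dir>/<project>/<name>
    return os.path.join(checkpoint_root_dir, project, name)

print(checkpoint_dir("/PATH/TO/CHECKPOINT", "Trinity-RFT", "example"))
# /PATH/TO/CHECKPOINT/Trinity-RFT/example
```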

examples/grpo_alfworld/alfworld.yaml (2 additions, 2 deletions)

```diff
@@ -6,6 +6,8 @@ algorithm:
   repeat_times: 8
 model:
   model_path: /PATH/TO/MODEL/
+  max_prompt_tokens: 4096
+  max_response_tokens: 16384
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -39,8 +41,6 @@ explorer:
   tensor_parallel_size: 2
   enable_prefix_caching: false
   enforce_eager: true
-  max_prompt_tokens: 4096
-  max_response_tokens: 16384
   dtype: bfloat16
   seed: 42
   gpu_memory_utilization: 0.7
```

examples/grpo_gsm8k/gsm8k.yaml (2 additions, 2 deletions)

```diff
@@ -20,6 +20,8 @@ data_processor:
 
 model:
   model_path: '/PATH/TO/MODEL/'
+  max_prompt_tokens: 256
+  max_response_tokens: 1024
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -71,8 +73,6 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  max_prompt_tokens: 256
-  max_response_tokens: 1024
   seed: 42
 synchronizer:
   sync_method: 'nccl'
```

examples/grpo_sciworld/sciworld.yaml (2 additions, 2 deletions)

```diff
@@ -6,6 +6,8 @@ algorithm:
   repeat_times: 8
 model:
   model_path: /PATH/TO/MODEL/
+  max_prompt_tokens: 4096
+  max_response_tokens: 16384
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -41,8 +43,6 @@ explorer:
   enforce_eager: true
   dtype: bfloat16
   seed: 42
-  max_prompt_tokens: 4096
-  max_response_tokens: 16384
   gpu_memory_utilization: 0.7
   enable_chunked_prefill: true
 synchronizer:
```

examples/grpo_webshop/webshop.yaml (2 additions, 2 deletions)

```diff
@@ -6,6 +6,8 @@ algorithm:
   repeat_times: 8
 model:
   model_path: /PATH/TO/MODEL/
+  max_prompt_tokens: 4096
+  max_response_tokens: 16384
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -40,8 +42,6 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  max_prompt_tokens: 4096
-  max_response_tokens: 16384
   seed: 42
   gpu_memory_utilization: 0.7
   enable_chunked_prefill: true
```

examples/opmd_gsm8k/opmd_gsm8k.yaml (2 additions, 2 deletions)

```diff
@@ -6,6 +6,8 @@ algorithm:
   repeat_times: 8
 model:
   model_path: /PATH/TO/MODEL/
+  max_prompt_tokens: 4096
+  max_response_tokens: 16384
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -40,8 +42,6 @@ explorer:
   tensor_parallel_size: 1
   enable_prefix_caching: false
   enforce_eager: true
-  max_prompt_tokens: 4096
-  max_response_tokens: 16384
   dtype: bfloat16
   seed: 42
 synchronizer:
```

tests/template/config.yaml (2 additions, 2 deletions)

```diff
@@ -7,6 +7,8 @@ algorithm:
   repeat_times: 1
 model:
   model_path: ''
+  max_prompt_tokens: 2048
+  max_response_tokens: 2048
 cluster: # 2 for explorer, 2 for trainer
   node_num: 1
   gpu_per_node: 4
@@ -33,8 +35,6 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  max_prompt_tokens: 2048
-  max_response_tokens: 2048
   seed: 42
   use_v1: true
 trainer:
```
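All six YAML diffs above make the same move: the token budgets leave the per-engine explorer settings and become shared `model`-level settings. An abbreviated sketch of the resulting shape (indentation, nesting, and the explorer keys shown are illustrative, not copied from the repo):

```yaml
# Shared model-level limits (new location)
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 4096
  max_response_tokens: 16384

# The explorer's inference engine no longer repeats these keys;
# unset values are filled in from `model` at config-check time.
explorer:
  engine_num: 1
  tensor_parallel_size: 1
```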

trinity/common/config.py (8 additions, 4 deletions)

```diff
@@ -118,8 +118,8 @@ class ModelConfig:
     # source model path
     model_path: str = ""
     critic_model_path: str = ""
-    max_prompt_tokens: int = 2048
-    max_response_tokens: int = 2048
+    max_prompt_tokens: Optional[int] = None
+    max_response_tokens: Optional[int] = None
 
 
 @dataclass
@@ -130,14 +130,14 @@ class InferenceModelConfig:
     engine_num: int = 1
     tensor_parallel_size: int = 1
     use_v1: bool = True
-    max_prompt_tokens: int = 2048
-    max_response_tokens: int = 2048
     enforce_eager: bool = True
     enable_prefix_caching: bool = False
     enable_chunked_prefill: bool = False
     gpu_memory_utilization: float = 0.9
     dtype: str = "bfloat16"
     seed: int = 42
+    max_prompt_tokens: Optional[int] = None
+    max_response_tokens: Optional[int] = None
     # override chat template in model
     chat_template: Optional[str] = None
     # For Qwen3
@@ -478,6 +478,10 @@ def check_and_update(self) -> None:  # noqa: C901
             and self.explorer.rollout_model.enable_openai_api
         ):
             raise ValueError("OpenAI API server only support `vllm_async` engine.")
+        if self.explorer.rollout_model.max_prompt_tokens is None:
+            self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
+        if self.explorer.rollout_model.max_response_tokens is None:
+            self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
 
         # check synchronizer
         self.synchronizer.explorer_world_size = (
```
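The `check_and_update` hunk above can be exercised in isolation with stand-in dataclasses (these minimal classes and the `apply_fallback` helper are a sketch, not the real `trinity` config API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfig:
    max_prompt_tokens: Optional[int] = None
    max_response_tokens: Optional[int] = None

@dataclass
class InferenceModelConfig:
    max_prompt_tokens: Optional[int] = None
    max_response_tokens: Optional[int] = None

def apply_fallback(model: ModelConfig, rollout: InferenceModelConfig) -> None:
    # An unset rollout limit inherits the shared model-level value;
    # an explicitly set rollout limit is left untouched.
    if rollout.max_prompt_tokens is None:
        rollout.max_prompt_tokens = model.max_prompt_tokens
    if rollout.max_response_tokens is None:
        rollout.max_response_tokens = model.max_response_tokens

model = ModelConfig(max_prompt_tokens=256, max_response_tokens=1024)
rollout = InferenceModelConfig()  # limits left unset
apply_fallback(model, rollout)
print(rollout.max_prompt_tokens, rollout.max_response_tokens)  # 256 1024
```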

trinity/common/models/vllm_async_model.py (4 additions, 1 deletion)

```diff
@@ -64,14 +64,17 @@ def __init__(
         )
         self.enable_thinking = config.enable_thinking
         self.request_id = 0
+        max_model_len = None
+        if config.max_prompt_tokens is not None and config.max_response_tokens is not None:
+            max_model_len = config.max_prompt_tokens + config.max_response_tokens
         engine_args = vllm.AsyncEngineArgs(
             model=config.model_path,
             enforce_eager=config.enforce_eager,
             worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension",
             tensor_parallel_size=config.tensor_parallel_size,
             seed=config.seed,
             distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"),
-            max_model_len=config.max_prompt_tokens + config.max_response_tokens,
+            max_model_len=max_model_len,
             enable_prefix_caching=config.enable_prefix_caching,
             dtype=config.dtype,
             trust_remote_code=True,
```