
Commit 5be7436

Async RL support multiple explorers (#100)
1 parent 3316e5b · commit 5be7436

29 files changed: +390 −98 lines

docs/sphinx_doc/source/tutorial/example_async_mode.md

Lines changed: 68 additions & 12 deletions
@@ -1,17 +1,17 @@
 # Asynchronous RFT

-This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset.
+This example demonstrates how to run RFT in fully asynchronous mode using the GRPO algorithm, Qwen2.5-1.5B-Instruct model, and GSM8K dataset.

-Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
+Trinity-RFT supports Asynchronous RFT by running the trainer and explorer in separate processes.

-For this purpose, we prepare two main config files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml).
-The main difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`.
+For this purpose, we provide two main configuration files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml).
+The primary difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`.
 The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.

-Suppose we have a node of 8 GPUs; we use 4 GPUs for the trainer and 4 GPUs for the explorer.
-Some important setups of `explorer.yaml` are listed in the following:
+Assuming we have a node with 8 GPUs, we allocate 4 GPUs for the trainer and 4 GPUs for the explorer. Key configurations in `explorer.yaml` are as follows:

 ```yaml
+# explorer.yaml
 project: <project_name>
 name: <experiment_name>
 mode: explore
@@ -26,7 +26,7 @@ cluster:
   gpu_per_node: 4
 buffer:
   total_epochs: 1
-  batch_size: 96
+  batch_size: 64
   explorer_input:
     taskset:
       name: gsm8k
@@ -45,7 +45,6 @@ buffer:
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
 explorer:
-  eval_interval: 10
   runner_num: 32
   rollout_model:
     engine_type: vllm_async
@@ -57,9 +56,10 @@ trainer:
   trainer_config_path: examples/async_gsm8k/verl_config.yaml
 ```

-Some important setups of `trainer.yaml` are listed in the following:
+Key configurations in `trainer.yaml` are as follows:

 ```yaml
+# trainer.yaml
 project: <project_name>
 name: <experiment_name>
 mode: train
@@ -74,7 +74,7 @@ cluster:
   gpu_per_node: 4
 buffer:
   total_epochs: 1
-  batch_size: 96
+  batch_size: 64
   explorer_input:
     taskset:
       name: gsm8k
@@ -98,8 +98,7 @@ trainer:
   trainer_config_path: examples/async_gsm8k/verl_config.yaml
 ```

-
-You may run this example with the following command:
+You can run this example with the following command:

 ```bash
 bash examples/async_gsm8k/run.sh
@@ -110,3 +109,60 @@ The following plot shows the learning curve of GRPO in the asynchronous mode.
 > We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in the asynchronous mode.

 ![async](../../assets/async-curve.png)
+
+
+Trinity-RFT also supports dynamic scaling in asynchronous mode. Continuing with the previous example, if an additional machine with 8 GPUs joins the Ray cluster during training, you can launch a new explorer using the following configuration `explorer_new.yaml`.
+
+```yaml
+# explorer_new.yaml
+project: <project_name>
+name: <experiment_name>
+mode: explore
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:  # important
+  node_num: 1
+  gpu_per_node: 8
+explorer:
+  name: 'explorer_new'  # important
+  runner_num: 64
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 8
+buffer:
+  total_epochs: 1
+  batch_size: 64
+  explorer_input:
+    taskset:  # important
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+      default_workflow_type: 'math_workflow'
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+      path: 'sqlite:///gsm8k.db'
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_interval: 10
+# other configs are the same as explorer.yaml
+```
+
+The differences between `explorer_new.yaml` and `explorer.yaml` include:
+
+- `cluster.node_num/gpu_per_node`: Specify the cluster configuration for the newly added explorer.
+- `explorer.name`: The later-started explorer requires a different name than "explorer", which is the default name of the existing explorer.
+- `explorer.rollout_model.engine_num/tensor_parallel_size`: Define the engine number and tensor parallel size to optimally utilize GPU resources.
+- `buffer.explorer_input.taskset`: Provide another task dataset as input for the new explorer.
+
+All other parameters remain the same as in `explorer.yaml`.
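
Why a unique `explorer.name` and a shared buffer name matter becomes clearer if you consider how Ray named actors work. The sketch below uses plain Ray only; `ExperienceQueue` and the namespace string are hypothetical stand-ins, not Trinity-RFT classes. It shows how a second process, such as the newly launched explorer, attaches to the same queue actor by name instead of creating its own.

```python
# Illustrative sketch in plain Ray (not Trinity-RFT internals): several explorer
# processes can feed one trainer because they all attach to the same *named*
# actor inside a shared namespace. `ExperienceQueue` is a hypothetical stand-in.
import ray

ray.init(namespace="my_project/my_experiment")  # analogous to <project>/<name>

@ray.remote
class ExperienceQueue:
    """Stand-in for the queue buffer behind `gsm8k_buffer`."""

    def __init__(self):
        self._items = []

    def put(self, exp):
        self._items.append(exp)

    def size(self):
        return len(self._items)

# Whichever process runs first creates the actor; a later process (e.g. the
# newly launched `explorer_new`) attaches to the existing one by name.
queue = ExperienceQueue.options(name="gsm8k_buffer", get_if_exists=True).remote()
ray.get(queue.put.remote({"prompt": "1+1=?", "response": "2"}))
print(ray.get(queue.size.remote()))  # 1
```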

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 43 additions & 4 deletions
@@ -68,6 +68,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT
   - `explore`: Only launches the explorer.
   - `bench`: Used for benchmarking.
 - `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
+- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.

 ---

@@ -166,6 +167,9 @@ buffer:
   eval_tasksets:
     ...
+  explorer_output:
+    ...
+
   trainer_input:
     experience_buffer:
       ...
@@ -219,15 +223,15 @@ buffer:

 The configuration for each task dataset is defined as follows:

-- `name`: Name of the dataset. Name must be unique.
+- `name`: Name of the dataset. This name will be used as the Ray actor's name, so it must be unique.
 - `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`.
   - `file`: The dataset is stored in `jsonl`/`parquet` files. The data file organization is required to meet the huggingface standard. *We recommend using this storage type for most cases.*
   - `queue`: The dataset is stored in a queue. The queue is a simple FIFO queue that stores the task dataset. *Do not use this storage type for task datasets unless you know what you are doing.*
   - `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in future versions.*
 - `path`: The path to the task dataset.
-  - For `file` storage type, the path is the path to the directory that contains the task dataset files.
+  - For `file` storage type, the path points to the directory that contains the task dataset files.
   - For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here.
-  - For `sql` storage type, the path is the path to the sqlite database file.
+  - For `sql` storage type, the path points to the sqlite database file.
 - `subset_name`: The subset name of the task dataset. Default is `None`.
 - `split`: The split of the task dataset. Default is `train`.
 - `format`: Defines keys for prompts and responses in the dataset.
@@ -240,6 +244,34 @@ The configuration for each task dataset is defined as follows:
 - `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.


+### Explorer Output
+
+In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_output`, rather than using `buffer.trainer_input`, which will be introduced in the next section.
+
+> For `both` and `train` modes, users should use `buffer.trainer_input` instead of `buffer.explorer_output`.
+
+```yaml
+buffer:
+  ...
+  explorer_output:
+    name: countdown_buffer
+    storage_type: queue
+    path: sqlite:///countdown_buffer.db
+    wrap_in_ray: True
+```
+
+- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
+- `storage_type`: The storage type for the experience buffer.
+  - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
+  - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes.
+  - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
+- `path`: The path to the experience buffer.
+  - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
+  - For `file` storage type, the path points to the directory containing the dataset files.
+  - For `sql` storage type, the path points to the SQLite database file.
+- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only takes effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor.
+
+
 ### Trainer Input

 Defines the experience buffer and optional SFT warm-up dataset.
@@ -264,7 +296,7 @@ buffer:
   sft_warmup_steps: 0
 ```

-- `experience_buffer`: Experience replay buffer used by the trainer.
+- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`.
 - `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
 - `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.

@@ -276,6 +308,7 @@ Controls the rollout models and workflow execution.

 ```yaml
 explorer:
+  name: explorer
   runner_num: 32
   rollout_model:
     engine_type: vllm_async
@@ -286,11 +319,13 @@ explorer:
     tensor_parallel_size: 1
 ```

+- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
 - `runner_num`: Number of parallel workflow runners.
 - `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
 - `rollout_model.engine_num`: Number of inference engines.
 - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
 - `auxiliary_models`: Additional models used for custom workflows.
+
 ---

@@ -301,13 +336,15 @@ Controls how model weights are synchronized between trainer and explorer.
 synchronizer:
   sync_method: 'nccl'
   sync_interval: 10
+  sync_offset: 0
   sync_timeout: 1200
 ```

 - `sync_method`: Method of synchronization. Options:
   - `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode.
   - `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode.
 - `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer.
+- `sync_offset`: Offset (in steps) of model weight synchronization between trainer and explorer. The explorer can run `sync_offset` steps before the trainer starts training.
 - `sync_timeout`: Timeout duration for synchronization.

 ---
@@ -318,12 +355,14 @@ Specifies the backend and behavior of the trainer.

 ```yaml
 trainer:
+  name: trainer
   trainer_type: 'verl'
   save_interval: 100
   trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
   trainer_config: null
 ```

+- `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique.
 - `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
 - `save_interval`: Frequency (in steps) at which to save model checkpoints.
 - `trainer_config_path`: The path to the trainer configuration file.
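
To make the `sync_interval` and `sync_offset` semantics concrete, here is a small sketch of the cadence they suggest. The scheduling rule below is our reading of the parameter descriptions in this diff, not code from Trinity-RFT.

```python
# Assumed semantics: the explorer may run `sync_offset` steps before the trainer
# starts, and weights are synchronized every `sync_interval` steps thereafter.
def sync_steps(total_steps: int, sync_interval: int = 10, sync_offset: int = 0) -> list[int]:
    """Explorer steps at which a weight sync would occur under this reading."""
    return [
        step
        for step in range(1, total_steps + 1)
        if step > sync_offset and (step - sync_offset) % sync_interval == 0
    ]

print(sync_steps(40))                                   # [10, 20, 30, 40]
print(sync_steps(40, sync_interval=10, sync_offset=5))  # [15, 25, 35]
```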

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 2 additions & 3 deletions
@@ -2,9 +2,8 @@

 This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines.

-Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**.
-We decouple the RL pipeline into three modules to make it easier to customize and extend.
-Below is a table summarizing the modules and components that developers with different tragets need to focus on.
+In Trinity-RFT, we decompose the RL pipeline into three main modules (**Explorer**, **Trainer** and **Buffer**) to facilitate customization and extension.
+Below is a table summarizing the modules and components that developers with different targets need to focus on.

 | Development Target | Core Module | Key Component |
 |--------------------|-------------|---------------|
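
For readers new to the codebase, the decomposition described above can be summarized with a toy sketch. This is illustrative only, not Trinity-RFT's actual interfaces: it shows that the Buffer decouples the Explorer's rollout loop from the Trainer's update loop, which is what makes the asynchronous and multi-explorer setups in this commit possible.

```python
# Toy sketch of the Explorer -> Buffer -> Trainer decoupling (illustrative only;
# the real modules communicate through Ray actors rather than a local deque).
from collections import deque

buffer: deque = deque()  # Buffer: the only point of contact between the two loops

def explore_step(task: str) -> None:
    """Explorer: roll out a task and write the experience into the buffer."""
    buffer.append({"task": task, "response": "...", "reward": 1.0})

def train_step(batch_size: int = 2) -> int:
    """Trainer: consume a batch of experiences from the buffer."""
    batch = [buffer.popleft() for _ in range(min(batch_size, len(buffer)))]
    return len(batch)  # number of experiences trained on

for task in ["2+2", "3*3", "10-4"]:
    explore_step(task)
print(train_step())  # 2
```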

tests/buffer/file_test.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def test_file_buffer(self):
         # test writer
         writer = JSONWriter(meta, None)
         writer.write(data)
-        writer.finish()
+        writer.release()

         # test reader
         meta.path = self.temp_output_path

tests/buffer/queue_test.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ def test_queue_buffer(self):
         )
         writer = QueueWriter(meta, config)
         reader = QueueReader(meta, config)
+        self.assertEqual(writer.acquire(), 1)
         exps = [
             Experience(
                 tokens=torch.tensor([float(j) for j in range(i + 1)]),
@@ -59,7 +60,7 @@
         )
         exps = reader.read(batch_size=put_batch_size * 2)
         self.assertEqual(len(exps), put_batch_size * 2)
-        writer.finish()
+        self.assertEqual(writer.release(), 0)
         self.assertRaises(StopIteration, reader.read)
         with open(BUFFER_FILE_PATH, "r") as f:
             self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)

tests/buffer/sql_test.py

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,7 @@ def test_create_sql_buffer(self) -> None:
             )
             for i in range(1, put_batch_size + 1)
         ]
+        self.assertEqual(sql_writer.acquire(), 1)
        for _ in range(total_num // put_batch_size):
             sql_writer.write(exps)
         for _ in range(total_num // read_batch_size):
@@ -65,3 +66,5 @@ def test_create_sql_buffer(self) -> None:
         self.assertEqual(len(exps), put_batch_size * 2)
         db_wrapper = ray.get_actor("sql-test_buffer")
         self.assertIsNotNone(db_wrapper)
+        self.assertEqual(sql_writer.release(), 0)
+        self.assertRaises(StopIteration, sql_reader.read)
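
The `acquire()`/`release()` pair that replaces `finish()` in these tests looks like writer reference counting. The sketch below is our inference from the assertions alone (return value 1 after one `acquire`, 0 after `release`, then `StopIteration` on `read`); the real `QueueWriter`/`SQLWriter` implementations may differ.

```python
# Inferred semantics of acquire()/release(), reconstructed from the test
# assertions above; not the actual Trinity-RFT buffer implementation.
class RefCountedBuffer:
    def __init__(self) -> None:
        self._writer_refs = 0
        self._items: list = []
        self._closed = False

    def acquire(self) -> int:
        self._writer_refs += 1
        return self._writer_refs  # tests expect 1 for the first writer

    def release(self) -> int:
        self._writer_refs -= 1
        if self._writer_refs == 0:
            self._closed = True  # last writer gone: readers may drain, then stop
        return self._writer_refs  # tests expect 0 after the last release

    def write(self, exps: list) -> None:
        self._items.extend(exps)

    def read(self):
        if not self._items:
            if self._closed:
                raise StopIteration  # matches assertRaises(StopIteration, reader.read)
            raise RuntimeError("no data yet; a real queue would block here")
        return self._items.pop(0)
```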

tests/explorer/runner_pool_test.py

Lines changed: 0 additions & 3 deletions
@@ -250,9 +250,6 @@ def test_runner_pool_with_auxiliary_models(self):
         )

         # `auxiliary_models`
-        st = time.time()
         status = pool.get_next_unorder()
-        et = time.time()
-        self.assertTrue(et - st < 1)
         self.assertEqual(len(status), 1)
         self.assertTrue(status[0].ok)

tests/tools.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def get_checkpoint_path() -> str:
 def get_unittest_dataset_config(
     dataset_name: str = "countdown", split: str = "train"
 ) -> StorageConfig:
-    """Countdown sample dataset for 8 steps"""
+    """Countdown dataset with 16 samples."""
     if dataset_name == "countdown" or dataset_name == "copy_countdown":
         return StorageConfig(
             name=dataset_name,
