Commit 21a9888

Add new scheduler with step granularity (#110)
1 parent 6eb2a44 commit 21a9888

25 files changed: +1056 −754 lines


docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 7 additions & 2 deletions
@@ -314,7 +314,7 @@ Controls the rollout models and workflow execution.
 ```yaml
 explorer:
   name: explorer
-  runner_num: 32
+  runner_per_model: 8
   max_timeout: 900
   max_retry_times: 2
   env_vars: {}
@@ -325,17 +325,22 @@ explorer:
   auxiliary_models:
     - model_path: /PATH/TO/MODEL
       tensor_parallel_size: 1
+  eval_interval: 100
+  eval_on_startup: True
 ```
 
 - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
-- `runner_num`: Number of parallel workflow runners.
+- `runner_per_model`: Number of parallel workflow runners per each rollout model.
 - `max_timeout`: Maximum time (in seconds) for a workflow to complete.
 - `max_retry_times`: Maximum number of retries for a workflow.
 - `env_vars`: Environment variables to be set for every workflow runners.
 - `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
 - `rollout_model.engine_num`: Number of inference engines.
 - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
 - `auxiliary_models`: Additional models used for custom workflows.
+- `eval_interval`: Interval (in steps) for evaluating the model.
+- `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting.
+- `runner_num`: (*Deprecated*) Number of parallel workflow runners.
 
 ---

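The same options also surface on the programmatic config object used throughout the test suite. As a rough illustration only (an assumption not shown in this diff: the YAML keys are taken to map one-to-one onto attributes of the object returned by the `get_template_config()` test helper, whose import path is also assumed):

```python
# Sketch only (assumption: YAML keys map 1:1 onto config attributes and
# get_template_config() lives in tests.tools; this diff shows neither).
from tests.tools import get_template_config

config = get_template_config()
config.explorer.runner_per_model = 8     # replaces the deprecated runner_num
config.explorer.eval_interval = 100      # evaluate every 100 steps
config.explorer.eval_on_startup = True   # also evaluate at step 0 on a fresh run
```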
examples/grpo_alfworld/alfworld.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ buffer:
3434
path: 'sqlite:///alfworld.db'
3535
explorer:
3636
runner_num: 32
37+
max_timeout: 3600
3738
rollout_model:
3839
engine_type: vllm_async
3940
engine_num: 2
@@ -44,10 +45,12 @@ explorer:
4445
seed: 42
4546
gpu_memory_utilization: 0.7
4647
enable_chunked_prefill: true
48+
env_vars:
49+
TMPDIR: /PATH/TO/ALFWORLD_TMP_DIR
4750
synchronizer:
4851
sync_method: 'nccl'
49-
sync_interval: 8
50-
sync_timeout: 1200
52+
sync_interval: 5
53+
sync_timeout: 3600
5154
trainer:
5255
trainer_type: 'verl'
5356
trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml'

tests/buffer/file_test.py

Lines changed: 14 additions & 5 deletions
@@ -16,7 +16,7 @@
 from trinity.common.constants import StorageType
 
 
-class TestFileBuffer(unittest.TestCase):
+class TestFileBuffer(unittest.IsolatedAsyncioTestCase):
     temp_output_path = "tmp/test_file_buffer/"
 
     @classmethod
@@ -30,7 +30,7 @@ def tearDownClass(cls):
         if os.path.exists(cls.temp_output_path):
             os.system(f"rm -rf {cls.temp_output_path}")
 
-    def test_file_buffer(self):
+    async def test_file_buffer(self):
         meta = StorageConfig(
             name="test_buffer",
             path=os.path.join(self.temp_output_path, "buffer.jsonl"),
@@ -46,8 +46,9 @@ def test_file_buffer(self):
 
         # test writer
         writer = JSONWriter(meta, None)
+        await writer.acquire()
         writer.write(data)
-        writer.release()
+        await writer.release()
 
         # test reader
         meta.path = self.temp_output_path
@@ -119,23 +120,31 @@ def test_file_reader(self): # noqa: C901
                 break
         self.assertEqual(len(tasks), 40 - 24)
 
-    def test_file_writer(self):
+    async def test_file_writer(self):
         writer = get_buffer_writer(
             self.config.buffer.trainer_input.experience_buffer, self.config.buffer
         )
+        await writer.acquire()
         writer.write(
             [
                 {"prompt": "hello world"},
                 {"prompt": "hi"},
             ]
        )
+        await writer.write_async(
+            [
+                {"prompt": "My name is"},
+                {"prompt": "What is your name?"},
+            ]
+        )
+        await writer.release()
         file_wrapper = ray.get_actor("json-test_buffer")
         self.assertIsNotNone(file_wrapper)
         file_path = default_storage_path(
             self.config.buffer.trainer_input.experience_buffer, self.config.buffer
         )
         with open(file_path, "r") as f:
-            self.assertEqual(len(f.readlines()), 2)
+            self.assertEqual(len(f.readlines()), 4)
 
     def setUp(self):
         self.config = get_template_config()

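The writer API exercised here is now awaitable: `acquire()`, `write_async()`, and `release()` are coroutines, so the test class switches to `unittest.IsolatedAsyncioTestCase` and the test methods become `async def`. Below is a minimal sketch of that lifecycle outside the test harness; it assumes a Ray runtime can be started locally, and the import paths for `get_template_config` and `get_buffer_writer` are assumptions (they are not shown in this diff).

```python
# Sketch only (not part of this commit): the acquire -> write -> release
# lifecycle that test_file_writer exercises. The two helper import paths
# below are assumptions; the diff does not show where they live.
import asyncio

import ray

from tests.tools import get_template_config   # assumed location
from trinity.buffer import get_buffer_writer  # assumed location


async def main():
    ray.init(ignore_reinit_error=True)
    config = get_template_config()
    writer = get_buffer_writer(
        config.buffer.trainer_input.experience_buffer, config.buffer
    )
    await writer.acquire()                        # register as an active writer
    writer.write([{"prompt": "hello world"}])     # synchronous write
    await writer.write_async([{"prompt": "hi"}])  # awaitable write
    await writer.release()                        # writer count drops back to 0
    ray.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```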
tests/buffer/queue_test.py

Lines changed: 35 additions & 6 deletions
@@ -1,9 +1,10 @@
 import os
+import threading
 import time
 
 import torch
 
-from tests.tools import RayUnittestBase
+from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.queue_reader import QueueReader
 from trinity.buffer.writer.queue_writer import QueueWriter
 from trinity.common.config import BufferConfig, StorageConfig
@@ -13,8 +14,8 @@
 BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
 
 
-class TestQueueBuffer(RayUnittestBase):
-    def test_queue_buffer(self):
+class TestQueueBuffer(RayUnittestBaseAysnc):
+    async def test_queue_buffer(self):
         total_num = 8
         put_batch_size = 2
         read_batch_size = 4
@@ -32,7 +33,7 @@ def test_queue_buffer(self):
         )
         writer = QueueWriter(meta, config)
         reader = QueueReader(meta, config)
-        self.assertEqual(writer.acquire(), 1)
+        self.assertEqual(await writer.acquire(), 1)
         exps = [
             Experience(
                 tokens=torch.tensor([float(j) for j in range(i + 1)]),
@@ -43,7 +44,7 @@ def test_queue_buffer(self):
             for i in range(1, put_batch_size + 1)
         ]
         for _ in range(total_num // put_batch_size):
-            writer.write(exps)
+            await writer.write_async(exps)
         for _ in range(total_num // read_batch_size):
             exps = reader.read()
             self.assertEqual(len(exps), read_batch_size)
@@ -62,7 +63,7 @@ def test_queue_buffer(self):
         )
         exps = reader.read(batch_size=put_batch_size * 2)
         self.assertEqual(len(exps), put_batch_size * 2)
-        self.assertEqual(writer.release(), 0)
+        self.assertEqual(await writer.release(), 0)
         self.assertRaises(StopIteration, reader.read)
         with open(BUFFER_FILE_PATH, "r") as f:
             self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
@@ -71,6 +72,34 @@ def test_queue_buffer(self):
         et = time.time()
         self.assertTrue(et - st > 2)
 
+        # test queue capacity
+        meta = StorageConfig(
+            name="test_buffer_small",
+            algorithm_type="ppo",
+            storage_type=StorageType.QUEUE,
+            max_read_timeout=3,
+            capacity=4,
+            path=BUFFER_FILE_PATH,
+        )
+        writer = QueueWriter(meta, config)
+        reader = QueueReader(meta, config)
+        writer.write([{"content": "hello"}])
+        writer.write([{"content": "hi"}])
+        writer.write([{"content": "hello"}])
+        writer.write([{"content": "hi"}])
+
+        # should be blocked
+        def write_blocking_call():
+            writer.write([{"content": "blocked"}])
+
+        thread = threading.Thread(target=write_blocking_call)
+        thread.start()
+        thread.join(timeout=2)
+        self.assertTrue(thread.is_alive(), "write() did not block as expected")
+        reader.read()
+        thread.join(timeout=1)
+        self.assertFalse(thread.is_alive())
+
     def setUp(self):
         if os.path.exists(BUFFER_FILE_PATH):
             os.remove(BUFFER_FILE_PATH)

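The second half of the new test covers the bounded-queue path: with `capacity=4`, a fifth `write()` blocks until a reader drains the queue, which the test verifies by issuing that write on a background thread and checking that it stays alive until `reader.read()` runs. Below is a condensed sketch of the same behaviour outside the test; it assumes a local Ray runtime and that a default `BufferConfig()` is enough to construct the queue (the test reuses a config built elsewhere in the file).

```python
# Sketch only (not part of this commit): a QueueWriter with capacity=4 blocks
# on the fifth write until a reader frees space. BufferConfig() defaults are
# an assumption; the test passes a config constructed earlier in the file.
import threading

import ray

from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType

ray.init(ignore_reinit_error=True)

meta = StorageConfig(
    name="capacity_demo",
    algorithm_type="ppo",
    storage_type=StorageType.QUEUE,
    max_read_timeout=3,
    capacity=4,                      # queue holds at most 4 items
    path="capacity_demo.jsonl",
)
config = BufferConfig()              # assumed: defaults suffice for this demo
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)

for i in range(4):
    writer.write([{"content": f"item-{i}"}])   # fills the queue to capacity

# The fifth write blocks, so run it on a worker thread like the test does.
blocked = threading.Thread(target=lambda: writer.write([{"content": "blocked"}]))
blocked.start()
blocked.join(timeout=2)
assert blocked.is_alive()            # still waiting for a free slot

reader.read()                        # draining the queue unblocks the writer
blocked.join()
ray.shutdown()
```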
tests/buffer/sql_test.py

Lines changed: 6 additions & 6 deletions
@@ -1,9 +1,9 @@
 import os
-import unittest
 
 import ray
 import torch
 
+from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
 from trinity.common.config import BufferConfig, StorageConfig
@@ -13,8 +13,8 @@
 db_path = os.path.join(os.path.dirname(__file__), "test.db")
 
 
-class TestSQLBuffer(unittest.TestCase):
-    def test_create_sql_buffer(self) -> None:
+class TestSQLBuffer(RayUnittestBaseAysnc):
+    async def test_create_sql_buffer(self) -> None:
         total_num = 8
         put_batch_size = 2
         read_batch_size = 4
@@ -42,9 +42,9 @@ def test_create_sql_buffer(self) -> None:
             )
             for i in range(1, put_batch_size + 1)
         ]
-        self.assertEqual(sql_writer.acquire(), 1)
+        self.assertEqual(await sql_writer.acquire(), 1)
         for _ in range(total_num // put_batch_size):
-            sql_writer.write(exps)
+            await sql_writer.write_async(exps)
         for _ in range(total_num // read_batch_size):
            exps = sql_reader.read()
             self.assertEqual(len(exps), read_batch_size)
@@ -66,5 +66,5 @@ def test_create_sql_buffer(self) -> None:
         self.assertEqual(len(exps), put_batch_size * 2)
         db_wrapper = ray.get_actor("sql-test_buffer")
         self.assertIsNotNone(db_wrapper)
-        self.assertEqual(sql_writer.release(), 0)
+        self.assertEqual(await sql_writer.release(), 0)
         self.assertRaises(StopIteration, sql_reader.read)

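All three buffer tests now inherit from `RayUnittestBaseAysnc` (the spelling follows the identifier exported by `tests/tools.py`) instead of `unittest.TestCase`/`RayUnittestBase`, and their `acquire()`/`release()` assertions treat the return values as a writer reference count that ends at 0, after which `read()` raises `StopIteration`. The helper base class itself is not part of this diff; the snippet below is a hypothetical reconstruction of what it plausibly provides, pairing Ray lifecycle management with `IsolatedAsyncioTestCase` so that `async def` tests can await the writer coroutines.

```python
# Hypothetical sketch (tests/tools.py is not shown in this commit): a base
# class combining asyncio test support with Ray setup/teardown, which is what
# the switch from unittest.TestCase to RayUnittestBaseAysnc implies.
import unittest

import ray


class RayUnittestBaseAysnc(unittest.IsolatedAsyncioTestCase):
    @classmethod
    def setUpClass(cls):
        # One Ray runtime per test class; ignore_reinit_error tolerates reuse.
        ray.init(ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()
```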