Commit 73e8d8f

Split group of tasks to multiple runners (#116)
1 parent 21a9888 commit 73e8d8f

File tree

11 files changed: +163 additions, −63 deletions


docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
-        "uid": np.array(experiences.run_ids),
+        "uid": np.array(experiences.group_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),

tests/buffer/queue_test.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ async def test_queue_buffer(self):
         with open(BUFFER_FILE_PATH, "r") as f:
             self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
         st = time.time()
-        self.assertRaises(StopIteration, reader.read, batch_size=1)
+        self.assertRaises(TimeoutError, reader.read, batch_size=1)
         et = time.time()
         self.assertTrue(et - st > 2)
 
7474

tests/explorer/scheduler_test.py

Lines changed: 70 additions & 6 deletions
@@ -8,7 +8,7 @@
 
 from tests.tools import get_template_config
 from trinity.buffer.reader.queue_reader import QueueReader
-from trinity.common.config import StorageConfig
+from trinity.common.config import GenerationConfig, StorageConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
 from trinity.common.models.model import InferenceModel
@@ -23,6 +23,7 @@ def __init__(self, model, task, auxiliary_models):
         super().__init__(model, task, auxiliary_models)
         self.error_type = task.raw_task.get("error_type", "")
         self.seconds = None
+        self.repeat_times = task.rollout_args.n
         if "timeout" in self.error_type:
             parts = self.error_type.split("_")
             if len(parts) > 1:
@@ -42,8 +43,12 @@ def run(self) -> List[Experience]:
 
        return [
            Experience(
-                tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success"
+                tokens=torch.zeros(5),
+                prompt_length=2,
+                prompt_text=self.error_type or "success",
+                info={"repeat_times": self.repeat_times},
            )
+            for _ in range(self.repeat_times)
        ]
 
 
@@ -98,7 +103,11 @@ def api_server_ready(self) -> Tuple[str, str]:
 
 
 def generate_tasks(
-    total_num: int, timeout_num: int = 0, exception_num: int = 0, timeout_seconds: int = 10
+    total_num: int,
+    timeout_num: int = 0,
+    exception_num: int = 0,
+    timeout_seconds: int = 10,
+    repeat_times: int = 1,
 ):
     """Generate some tasks for testing
 
@@ -108,7 +117,10 @@ def generate_tasks(
         exception_num: number of exception tasks
         timeout_seconds: the timeout for timeout tasks
     """
-    tasks = [Task(workflow=DummyWorkflow, raw_task={}) for _ in range(total_num)]
+    tasks = [
+        Task(workflow=DummyWorkflow, raw_task={}, rollout_args=GenerationConfig(n=repeat_times))
+        for _ in range(total_num)
+    ]
 
     tasks.extend(
         [
@@ -150,6 +162,9 @@ def setUp(self):
             algorithm_type="ppo",
             path="",
         )
+        self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 1
+        self.config.algorithm.repeat_times = 1
+        self.config.check_and_update()
         self.queue = QueueReader(
             self.config.buffer.trainer_input.experience_buffer, self.config.buffer
         )
@@ -163,6 +178,9 @@ async def test_get_results(self):
 
         results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20)
         self.assertEqual(len(results), 8)
+        self.assertEqual(len(self.queue.read(batch_size=8)), 8)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
 
         for result in results:
             self.assertTrue(result.ok)
@@ -176,13 +194,17 @@ async def test_get_results(self):
         results = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10)
         self.assertEqual(len(results), 4)
         self.assertFalse(scheduler.has_step(batch_id))
+        self.assertEqual(len(self.queue.read(batch_size=4)), 4)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
 
         tasks = generate_tasks(3)
         scheduler.schedule(tasks, batch_id=4)
         self.assertTrue(scheduler.has_step(4))
         results = await scheduler.get_results(batch_id=4)
         self.assertEqual(len(results), 3)
         self.assertFalse(scheduler.has_step(4))
+        self.assertEqual(len(self.queue.read(batch_size=3)), 3)
 
         # test timeout
         tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10)
@@ -194,6 +216,7 @@ async def test_get_results(self):
 
         self.assertLessEqual(end_time - start_time, 5)
         self.assertEqual(len(results), 2)
+        self.assertEqual(len(self.queue.read(batch_size=2)), 2)
 
         # test run tasks after timeout
         tasks = generate_tasks(4)
@@ -204,8 +227,10 @@ async def test_get_results(self):
         self.assertEqual(len(results), 4)
 
         success_count = sum(1 for r in results if r.ok)
-
-        self.assertEqual(success_count, sum(1 for r in results if r.ok))
+        self.assertEqual(success_count, 4)
+        self.assertEqual(len(self.queue.read(batch_size=4)), 4)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
 
         # test exception tasks
         tasks = generate_tasks(1, exception_num=3)
@@ -215,14 +240,21 @@ async def test_get_results(self):
 
         success_count = sum(1 for r in results if r.ok)
         self.assertEqual(success_count, 1)
+        self.assertEqual(len(self.queue.read(batch_size=1)), 1)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
 
         # test clear_timeout_tasks
         tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3)
         scheduler.schedule(tasks, batch_id=2)
         results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False)
         self.assertEqual(len(results), 3)
+        self.assertEqual(len(self.queue.read(batch_size=3)), 3)
         results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False)
         self.assertEqual(len(results), 1)
+        self.assertEqual(len(self.queue.read(batch_size=1)), 1)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
 
         await scheduler.stop()
 
@@ -366,6 +398,38 @@ async def test_scheduler_all_methods(self):
         self.assertFalse(scheduler.has_step(2))
         await scheduler.stop()
 
+    async def test_split_tasks(self):
+        self.config.explorer.max_repeat_times_per_runner = 2
+        self.config.check_and_update()
+        scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
+        await scheduler.start()
+
+        tasks = generate_tasks(4, repeat_times=8)  # ceil(8 / 2) == 4
+        scheduler.schedule(tasks, batch_id=1)
+        results = await scheduler.get_results(batch_id=1)
+        self.assertEqual(len(results), 4 * 4)
+        self.assertEqual(len(self.queue.read(batch_size=4 * 8)), 4 * 8)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
+
+        tasks = generate_tasks(4, repeat_times=5)  # ceil(5 / 2) == 3
+        scheduler.schedule(tasks, batch_id=1)
+        results = await scheduler.get_results(batch_id=1)
+        self.assertEqual(len(results), 4 * 3)
+        self.assertEqual(len(self.queue.read(batch_size=4 * 5)), 4 * 5)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
+
+        tasks = generate_tasks(3, repeat_times=1)  # ceil(1 / 2) == 1
+        scheduler.schedule(tasks, batch_id=1)
+        results = await scheduler.get_results(batch_id=1)
+        self.assertEqual(len(results), 3 * 1)
+        self.assertEqual(len(self.queue.read(batch_size=3 * 1)), 3 * 1)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
+
+        await scheduler.stop()
+
     def tearDown(self):
         try:
             ray.shutdown()
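The new `test_split_tasks` case pins down the splitting arithmetic: with `explorer.max_repeat_times_per_runner = 2`, a task whose `rollout_args.n` is 8 is fanned out as `ceil(8 / 2) == 4` sub-tasks, so the scheduler returns 4 results per task while the buffer still receives all 8 experiences. Below is a minimal sketch of that splitting logic; `split_task` and the stand-in dataclasses are illustrative assumptions for this page, not the scheduler's actual implementation:

```python
import math
from dataclasses import dataclass, field, replace
from typing import List, Optional
from uuid import uuid4


@dataclass
class GenerationConfig:  # minimal stand-in for trinity.common.config.GenerationConfig
    n: int = 1


@dataclass
class Task:  # minimal stand-in for trinity.common.workflows.workflow.Task
    rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
    group_id: Optional[str] = None


def split_task(task: Task, max_repeat_times_per_runner: int) -> List[Task]:
    """Split one task into ceil(n / max) sub-tasks, each capped at `max`
    repeats; all sub-tasks share one group_id so their experiences can be
    regrouped downstream (see the Task.group_id field added in this commit)."""
    n = task.rollout_args.n
    group_id = str(uuid4())  # assumption: one shared id per original task
    chunks = [min(max_repeat_times_per_runner, n - i)
              for i in range(0, n, max_repeat_times_per_runner)]
    assert len(chunks) == math.ceil(n / max_repeat_times_per_runner)
    return [replace(task, rollout_args=GenerationConfig(n=c), group_id=group_id)
            for c in chunks]


# Mirrors the test's comments: ceil(8 / 2) == 4, ceil(5 / 2) == 3, ceil(1 / 2) == 1
assert len(split_task(Task(GenerationConfig(n=8)), 2)) == 4
assert len(split_task(Task(GenerationConfig(n=5)), 2)) == 3
assert len(split_task(Task(GenerationConfig(n=1)), 2)) == 1
```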

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) ->
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
-        "uid": np.array(experiences.run_ids),
+        "uid": np.array(experiences.group_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),

trinity/algorithm/sample_strategy/utils.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def to_data_proto(experiences: Experiences) -> DataProto:
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
-        "uid": np.array(experiences.run_ids),
+        "uid": np.array(experiences.group_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),

trinity/buffer/reader/queue_reader.py

Lines changed: 3 additions & 1 deletion
@@ -31,7 +31,9 @@ def read(
             batch_size = batch_size or self.read_batch_size
             exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout))
             if len(exps) != batch_size:
-                raise StopIteration("Read incomplete batch, please check your workflow.")
+                raise TimeoutError(
+                    f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow."
+                )
         except StopAsyncIteration:
             raise StopIteration()
         return exps
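With this change the reader distinguishes a timed-out, incomplete batch (`TimeoutError`) from an exhausted queue (`StopIteration`), which the updated tests above rely on. A hedged usage sketch follows; `reader` is assumed to be a `QueueReader` instance, and the recovery policy shown is illustrative, not part of the library:

```python
# Sketch: an incomplete batch is now a transient condition, not end-of-stream.
try:
    batch = reader.read(batch_size=32)
except TimeoutError:
    # Fewer than 32 experiences arrived within max_read_timeout; the queue
    # may still fill up later, so a caller might retry or shrink the batch.
    batch = []
except StopIteration:
    # The queue is closed; no more experiences will ever arrive.
    raise
```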

trinity/common/config.py

Lines changed: 4 additions & 2 deletions
@@ -310,6 +310,9 @@ class ExplorerConfig:
     max_timeout: int = 1800  # wait each task for 30 minutes
     max_retry_times: int = 2  # retry each task for 2 times if it fails or timeout
     env_vars: dict = field(default_factory=dict)  # environment variables for workflow runner
+    max_repeat_times_per_runner: Optional[
+        int
+    ] = None  # the number of times to repeat each task in a single workflow runner (for GRPO-like algorithms)
 
     runner_num: Optional[int] = None  # deprecated
 
@@ -358,9 +361,8 @@ class MonitorConfig:
 
 @dataclass
 class SynchronizerConfig:
-    """Configs for model weight synchronization"""
+    """Configs for model weight synchronization."""
 
-    # TODO: rename to "checkpoint", "nccl", "ipc"
     sync_method: SyncMethod = SyncMethod.NCCL
     # sync weights every `sync_interval` steps
     sync_interval: int = 1
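The new knob is opt-in: left at its default of `None`, each task presumably runs whole on a single runner. A minimal sketch of enabling it, modeled directly on the test's `setUp`; `get_template_config` is the helper imported from `tests.tools` in `scheduler_test.py`:

```python
from tests.tools import get_template_config

config = get_template_config()
# Cap each runner at 2 repeats: a task with rollout_args.n == 8 is then
# fanned out across ceil(8 / 2) == 4 runners (see test_split_tasks above).
config.explorer.max_repeat_times_per_runner = 2
config.check_and_update()
```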

trinity/common/experience.py

Lines changed: 8 additions & 8 deletions
@@ -26,7 +26,7 @@ class Experience:
     rejected: Optional[Tensor] = None  # for dpo
     info: Optional[dict] = None
     metrics: Optional[dict[str, float]] = None
-    run_id: str = ""
+    group_id: str = ""  # for grpo
 
     def __post_init__(self):
         if self.action_mask is not None:
@@ -95,7 +95,7 @@ class Experiences:
     action_masks: Optional[Tensor]
     prompt_length: int
     logprobs: Optional[Tensor]
-    run_ids: List[str]
+    group_ids: List[str]
 
     @property
     def batch_size(self) -> int:
@@ -118,11 +118,11 @@ def gather_experiences(
             action_masks=torch.empty(0, dtype=torch.bool),
             logprobs=torch.empty(0, dtype=torch.float32),
             prompt_length=torch.empty(0, dtype=torch.int32),
-            run_ids=[],
+            group_ids=[],
         )
     max_prompt_length = max([exp.prompt_length for exp in experiences])
     max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])
-    run_ids = [exp.run_id for exp in experiences]
+    group_ids = [exp.group_id for exp in experiences]
     tokens_dtype = experiences[0].tokens.dtype
     tokens = torch.stack(
         [
@@ -208,7 +208,7 @@ def gather_experiences(
         logprobs = None
 
     return cls(
-        run_ids=run_ids,
+        group_ids=group_ids,
         tokens=tokens,
         rewards=rewards,
         attention_masks=attention_masks,
@@ -249,7 +249,7 @@ def gather_dpo_experiences(
             action_masks=torch.empty(0, dtype=torch.bool),
             logprobs=torch.empty(0, dtype=torch.float32),
             prompt_length=torch.empty(0, dtype=torch.int32),
-            run_ids=[],
+            group_ids=[],
         )
 
     # TODO: exp.tokens in DPO are prompt tokens
@@ -261,7 +261,7 @@ def gather_dpo_experiences(
     response_tokens = list(chain.from_iterable(zip(chosen_tokens, rejected_tokens)))
     max_response_length = max([len(response) for response in response_tokens])  # type: ignore
 
-    run_ids = list(chain.from_iterable([repeat(exp.run_id, 2) for exp in experiences]))
+    group_ids = list(chain.from_iterable([repeat(exp.group_id, 2) for exp in experiences]))
     tokens_dtype = experiences[0].tokens.dtype
     tokens = torch.stack(
         [
@@ -297,7 +297,7 @@ def gather_dpo_experiences(
     assert len(tokens) == 2 * len(experiences)
 
     return cls(
-        run_ids=run_ids,
+        group_ids=group_ids,
         tokens=tokens,
         attention_masks=attention_masks,
         prompt_length=max_prompt_length,
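Renaming `run_id` to `group_id` makes the GRPO intent explicit: all experiences produced for one task share a `group_id`, and `to_data_proto` exposes it as `uid` so advantages can be normalized within each group even when the group was split across runners. A generic GRPO-style grouping sketch follows; this is not this repository's trainer code, just an illustration of why the shared id matters:

```python
from collections import defaultdict
from typing import Dict, List


def group_relative_advantages(
    group_ids: List[str], rewards: List[float], eps: float = 1e-6
) -> List[float]:
    """Normalize each reward against its own group's mean and std (GRPO-style)."""
    groups: Dict[str, List[float]] = defaultdict(list)
    for gid, r in zip(group_ids, rewards):
        groups[gid].append(r)
    stats = {}
    for gid, rs in groups.items():
        mean = sum(rs) / len(rs)
        std = (sum((r - mean) ** 2 for r in rs) / len(rs)) ** 0.5
        stats[gid] = (mean, std)
    return [
        (r - stats[gid][0]) / (stats[gid][1] + eps)
        for gid, r in zip(group_ids, rewards)
    ]


# Two repeats of task "a" and two of task "b": each reward is compared
# only against rewards carrying the same group_id.
print(group_relative_advantages(["a", "a", "b", "b"], [1.0, 0.0, 3.0, 5.0]))
```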

trinity/common/workflows/workflow.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ class Task:
     reward_fn: Optional[Type[RewardFn]] = None
     raw_task: Optional[dict] = None  # The raw data sample
 
+    group_id: Optional[str] = None  # for GRPO-like algorithms, automatically assigned
+
     def to_workflow(
         self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None
     ) -> Workflow:
