Skip to content

Commit 6d84662

Browse files
authored
fix bug in fully async example (issue #488) (#519)
1 parent 7b2701c commit 6d84662

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

examples/fully_async/fully_async_rollout.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,23 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> List[Lis
193193

194194
group = completed_groups.pop(group_id)
195195

196+
# If any sample in the group was aborted, return the whole group to the data buffer
197+
# and do not forward it to the training engine.
198+
try:
199+
any_aborted = any([sample.status == Sample.Status.ABORTED for sample in group])
200+
except Exception:
201+
any_aborted = False
202+
203+
if any_aborted:
204+
try:
205+
# add back to buffer so it can be retried or handled by buffer policy
206+
data_buffer.add_samples([group])
207+
print(f"Returned aborted group {group_id} to data buffer", flush=True)
208+
except Exception as e:
209+
print(f"Failed to return aborted group {group_id} to buffer: {e}", flush=True)
210+
# don't count as processed for training
211+
continue
212+
196213
if do_print:
197214
print(
198215
f"First rollout sample: {[group[0].prompt + group[0].response]}, "

examples/fully_async/run-qwen3-4b-fully_async.sh

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,19 @@ CKPT_ARGS=(
3535
--save-interval 20
3636
)
3737

38+
PROMPT_SET=/path/to/dapo-math-17k.jsonl
39+
3840
ROLLOUT_ARGS=(
3941
--rollout-function-path fully_async_rollout.generate_rollout_fully_async
40-
--prompt-data /mnt/o1_alicloud/personal/zzl/rl_data/dapo-math-17k.jsonl
42+
--prompt-data ${PROMPT_SET}
4143
--input-key prompt
4244
--label-key label
4345
--apply-chat-template
4446
--rollout-shuffle
45-
--rm-type deepscaler
47+
48+
--rm-type dapo
49+
--reward-key score
50+
4651
--num-rollout 3000
4752
--rollout-batch-size 32
4853
--n-samples-per-prompt 8

slime/rollout/rm_hub/math_dapo_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ def __exit__(self, type, value, traceback):
136136
"{,}",
137137
'"',
138138
"\\dots",
139+
"<|im_end|>",
140+
"<|endoftext|>",
139141
]
140142

141143

@@ -206,6 +208,8 @@ def is_correct_minerva(
206208
else:
207209
gt = normalize_final_answer(gt)
208210

211+
gt = str(int(float(gt))) # in dapo, all answers are integers
212+
209213
return (pred == gt), pred
210214

211215

0 commit comments

Comments
 (0)