Skip to content

Commit 2267455

Browse files
authored
cancel all rollout eval (#1671)
1 parent 667fb38 commit 2267455

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/prime_rl/orchestrator/orchestrator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,10 @@ async def orchestrate(config: OrchestratorConfig):
308308
last_eval_step = ckpt_step
309309
logger.info(f"Running evals for checkpoint step {ckpt_step} (blocking, subprocess)")
310310

311-
# Pause weight updates during eval
311+
# Pause weight updates during eval and cancel inflight rollouts
312+
# this avoids running evals across different checkpoints and avoids congestion
312313
scheduler.checkpoint_ready.clear()
314+
scheduler.cancel_all_inflight_rollouts()
313315

314316
await run_evals_subprocess(
315317
client_config=config.client,

src/prime_rl/orchestrator/scheduler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,18 @@ async def stop(self):
152152
for worker in workers:
153153
worker.stop()
154154

155+
def cancel_all_inflight_rollouts(self):
156+
"""Cancel all in-flight rollout requests.
157+
158+
Used to discard stale rollouts generated with old weights, e.g. when weights are updated or before blocking evals run.
159+
"""
160+
count = len(self.inflight_group_rollouts)
161+
for future in list(self.inflight_group_rollouts.keys()):
162+
if not future.done():
163+
future.cancel()
164+
self.inflight_group_rollouts.clear()
165+
self.cancelled_rollouts_count += count
166+
155167
async def schedule_group_rollout(self):
156168
"""Asynchronously schedules a group rollout request."""
157169
example = self.buffer.sample_examples(n=1)[0]

0 commit comments

Comments
 (0)