Skip to content

Commit 8fa0dde

Browse files
mikasenghaasclaude
andauthored
feat: make max_retries configurable per training env (#2025)
Move max_retries from EvalEnvConfig to EnvConfig so it applies to both training and eval environments. The scheduler now looks up max_retries per task from the env config instead of hardcoding 0. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7694668 commit 8fa0dde

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

src/prime_rl/configs/orchestrator.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,13 @@ class EnvConfig(BaseConfig):
287287
),
288288
),
289289
] = {}
290+
max_retries: Annotated[
291+
int,
292+
Field(
293+
ge=0,
294+
description="Maximum number of times the environment will retry a failed rollout.",
295+
),
296+
] = 0
290297

291298
@property
292299
def resolved_name(self) -> str:
@@ -324,17 +331,6 @@ class EvalEnvConfig(EnvConfig):
324331
),
325332
] = 0
326333

327-
# TODO: should live on the EnvConfig and also apply to training envs but
328-
# this is hard right now because we use the vf.EnvGroup which treats all
329-
# envs as one. for now training envs hardcode no retries, but we should
330-
# probably treat them like environment groups long-term
331-
max_retries: Annotated[
332-
int,
333-
Field(
334-
description="Maximum number of times the environment will try to retry running a rollout.",
335-
),
336-
] = 0
337-
338334

339335
class ValConfig(BaseConfig):
340336
"""Configures the validation of the model."""

src/prime_rl/orchestrator/scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
# Inference pool - used for admin operations (adapter sync) and metrics
9494
self.inference_pool = inference_pool
9595

96+
self.max_retries_by_task = {env.resolved_name: env.max_retries for env in config.env}
9697
self.deferred_group_scoring_tasks = set(deferred_group_scoring_tasks or ())
9798
if self.deferred_group_scoring_tasks:
9899
task_list = ", ".join(sorted(self.deferred_group_scoring_tasks))
@@ -203,7 +204,7 @@ async def schedule_rollout(self, group_id: int):
203204
example=group.example,
204205
model_name=self.model_name,
205206
sampling_args=self.sampling_args,
206-
max_retries=0, # TODO: make configurable
207+
max_retries=self.max_retries_by_task.get(group.example["task"], 0),
207208
)
208209
)
209210
self.inflight_requests[run_rollout_task] = InflightRolloutInfo(

0 commit comments

Comments
 (0)