Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ async def main(cfg: DictConfig):
),
)

if "steps" not in cfg.trainer.training:
raise ValueError("`cfg.trainer.training.steps` must be defined (can be null).")

max_steps = cfg.trainer.training.steps

print("All services initialized successfully!")

# ---- Core RL loops ---- #
Expand Down Expand Up @@ -433,7 +438,6 @@ async def continuous_rollouts():
async def continuous_training():
training_step = 0
restart_tracer = True # Flag to control when to restart tracer
max_steps = cfg.trainer.training.get("steps", None)

while max_steps is None or training_step < max_steps:
Copy link
Contributor

@casteryh casteryh Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should make `cfg.training.steps' a required argument. And explicitly setting to -1 means run until interrupted.

Copy link
Member Author

@DNXie DNXie Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made max_steps

max_steps = cfg.trainer.training.steps or -1

and the condition

while max_steps < 0 or training_step < max_steps:

# Restart tracer when needed (initial start or after completing a training step)
Expand Down Expand Up @@ -492,11 +496,12 @@ async def continuous_training():
print("Training interrupted by user")
finally:
print("Shutting down...")

for rollout_task in rollout_tasks:
rollout_task.cancel()
# graceful await all tasks, ignore cancellation noise
await asyncio.gather(*rollout_tasks, return_exceptions=True)
# Give replicas time to drain and complete in-flight requests
await asyncio.sleep(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can/should make this part more graceful. Proposal:

Continuous rollouts takes a shutdown event:

async def continuous_rollouts(shutdown_evnt: asyncio.Event):
    ...
    while not shutdown_event.is_set(): # no more while True
        ...

    print("Rollout loop got shutdown event, shutting down...")

then in our finally we can do a 2-phased shutdown:

finally:
    print("Shutting down...")
    shutdown_event.set()

    try:
        # give tasks a chance to exit gracefully
        await asyncio.wait_for(
            asyncio.gather(*rollout_tasks, return_exceptions=True),
            timeout=5
        )
    except asyncio.TimeoutError:
        print("Forcing cancellation...")
        for t in rollout_tasks:
            t.cancel()
        await asyncio.gather(*rollout_tasks, return_exceptions=True)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

training_task.cancel()

# give mlogger time to shutdown backends, otherwise they can stay running.
Expand Down
5 changes: 5 additions & 0 deletions src/forge/controller/service/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ async def _process_single_request(self, request: ServiceRequest) -> bool:
# can be healthy but the request failed.
self.mark_failed()
success = False
except asyncio.InvalidStateError:
# Future was already cancelled — safe to ignore during shutdown
self.mark_failed()
success = False
pass
except Exception as e:
logger.debug(f"Got unexpected error on replica {self.idx}. Error:\n{e}")
self.mark_failed()
Expand Down
Loading