Skip to content

Conversation

DNXie
Copy link
Member

@DNXie DNXie commented Oct 10, 2025

Previously the GRPO loop ran indefinitely with both training and rollout tasks active.
This PR adds a cfg.training.steps limit so training stops after the specified number of steps, then cleanly terminates rollout tasks.

Now logs end with messages like (I tested with cfg.training.steps=2):

Reached training limit (2 steps). Exiting continuous_training loop.
Shutting down...
WandbBackend global_controller: Finished run
Health loop stopped gracefully.
Health loop stopped gracefully.
Health loop stopped gracefully.
... 
Shutting down provisioner..

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 10, 2025
@DNXie DNXie requested a review from Jack-Khuu October 10, 2025 18:45
# Flush metrics every training step to WandB
await mlogger.flush.call_one(training_step)

print(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we get the logger from src/forge/util/logging.py instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really use the logger elsewhere in the main script do we?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The entire main script is using print instead of logger right now.

async def continuous_training():
training_step = 0
restart_tracer = True # Flag to control when to restart tracer
max_steps = cfg.trainer.training.get("steps", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put this at the top of the main() loop? And also make it required please. The default can still be null / None, but this way it's very visible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Comment on lines 487 to 500
training_task = asyncio.create_task(continuous_training())

try:
await asyncio.gather(*rollout_tasks, training_task)
await training_task
except KeyboardInterrupt:
print("Training interrupted by user")
finally:
print("Shutting down...")

for rollout_task in rollout_tasks:
rollout_task.cancel()
# graceful await all tasks, ignore cancellation noise
await asyncio.gather(*rollout_tasks, return_exceptions=True)
training_task.cancel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain this change a bit? before we would do gather in the try, and they would be impacted by KeyboardInterrupt. Now the gather happens after KeyboardInterrupt. Would the user possibly have to run 'KeyboardInterrupt' twice?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain this change a bit? before we would do gather in the try, and they would be impacted by KeyboardInterrupt. Now the gather happens after KeyboardInterrupt. Would the user possibly have to run 'KeyboardInterrupt' twice?

I think after the first KeyboardInterrupt, all the tasks are canceled and now we just gather on the canceled tasks (which should be fast to resolve) for graceful shutdown.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous gather(*rollout_tasks, training_task) call blocked on rollouts indefinitely, even after training_task completed (e.g., once max_steps was reached).
Since we now have a finite step limit that cleanly terminates training_task, we shouldn’t continue waiting on rollout tasks.

With this change, rollout termination is handled explicitly in the finally block.
This doesn’t theoretically change how KeyboardInterrupt is handled, since an interrupt caught in the except would still flow into finally.
However, we’ve already seen issues with KeyboardInterrupt handling not working properly (see #360 (2)); I’ll look into that separately in a follow-up PR.

restart_tracer = True # Flag to control when to restart tracer

while True:
while max_steps is None or training_step < max_steps:
Copy link
Contributor

@casteryh casteryh Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should make `cfg.training.steps' a required argument. And explicitly setting to -1 means run until interrupted.

Copy link
Member Author

@DNXie DNXie Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made max_steps

max_steps = cfg.trainer.training.steps or -1

and the condition

while max_steps < 0 or training_step < max_steps:

# graceful await all tasks, ignore cancellation noise
await asyncio.gather(*rollout_tasks, return_exceptions=True)
# Give replicas time to drain and complete in-flight requests
await asyncio.sleep(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can/should make this part more graceful. Proposal:

Continuous rollouts takes a shutdown event:

async def continuous_rollouts(shutdown_evnt: asyncio.Event):
    ...
    while not shutdown_event.is_set(): # no more while True
        ...

    print("Rollout loop got shutdown event, shutting down...")

then in our finally we can do a 2-phased shutdown:

finally:
    print("Shutting down...")
    shutdown_event.set()

    try:
        # give tasks a chance to exit gracefully
        await asyncio.wait_for(
            asyncio.gather(*rollout_tasks, return_exceptions=True),
            timeout=5
        )
    except asyncio.TimeoutError:
        print("Forcing cancellation...")
        for t in rollout_tasks:
            t.cancel()
        await asyncio.gather(*rollout_tasks, return_exceptions=True)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@DNXie DNXie requested a review from allenwang28 October 13, 2025 19:18
Copy link
Contributor

@allenwang28 allenwang28 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as long as it's still runnig correctly!

@DNXie DNXie merged commit 06a0ae7 into meta-pytorch:main Oct 13, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants