Training termination for grpo/main with finite step limit #376
Conversation
# Flush metrics every training step to WandB
await mlogger.flush.call_one(training_step)

print(
Can we get the logger from src/forge/util/logging.py instead?
We don't really use the logger elsewhere in the main script, do we?
The entire main script is using print instead of logger right now.
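For context, a minimal sketch of the suggested swap, using the stdlib logging module as a stand-in since the exact interface of src/forge/util/logging.py isn't shown in this thread:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# e.g., instead of print("Shutting down...")
logger.info("Shutting down...")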
apps/grpo/main.py
Outdated
async def continuous_training():
    training_step = 0
    restart_tracer = True  # Flag to control when to restart tracer
    max_steps = cfg.trainer.training.get("steps", None)
Can you put this at the top of the main() loop? And also make it required, please. The default can still be null/None, but this way it's very visible.
Done!
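A rough illustration of the requested shape; the cfg access mirrors the diff above, and everything else is a hypothetical stub rather than the actual apps/grpo/main.py code:

import asyncio
from types import SimpleNamespace

# Stub standing in for the real config object.
cfg = SimpleNamespace(trainer=SimpleNamespace(training=SimpleNamespace(steps=None)))

async def main() -> None:
    # Read the step limit at the very top of main() so it is clearly visible;
    # the key stays required in the config, and null/None means "run until interrupted".
    max_steps = cfg.trainer.training.steps
    print(f"max_steps = {max_steps}")

asyncio.run(main())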
training_task = asyncio.create_task(continuous_training())

try:
    await training_task
except KeyboardInterrupt:
    print("Training interrupted by user")
finally:
    print("Shutting down...")

    for rollout_task in rollout_tasks:
        rollout_task.cancel()
    # graceful await all tasks, ignore cancellation noise
    await asyncio.gather(*rollout_tasks, return_exceptions=True)
    training_task.cancel()
Can you explain this change a bit? Before, we would do the gather in the try, and it would be impacted by KeyboardInterrupt. Now the gather happens after KeyboardInterrupt. Would the user possibly have to send KeyboardInterrupt twice?
> Can you explain this change a bit? Before, we would do the gather in the try, and it would be impacted by KeyboardInterrupt. Now the gather happens after KeyboardInterrupt. Would the user possibly have to send KeyboardInterrupt twice?
I think after the first KeyboardInterrupt, all the tasks are canceled and now we just gather on the canceled tasks (which should be fast to resolve) for graceful shutdown.
The previous gather(*rollout_tasks, training_task) call blocked on rollouts indefinitely, even after training_task completed (e.g., once max_steps was reached). Since we now have a finite step limit that cleanly terminates training_task, we shouldn't continue waiting on rollout tasks. With this change, rollout termination is handled explicitly in the finally block.

In theory, this doesn't change how KeyboardInterrupt is handled, since an interrupt caught in the except would still flow into finally. However, we've already seen issues with KeyboardInterrupt handling not working properly (see #360 (2)); I'll look into that separately in a follow-up PR.
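A minimal, self-contained sketch of the termination flow described above (hypothetical task bodies and step counts; not the actual apps/grpo/main.py code):

import asyncio

async def continuous_training(max_steps: int) -> None:
    # Finishes on its own once the step limit is reached.
    for _ in range(max_steps):
        await asyncio.sleep(0.01)  # stand-in for one training step
    print("training done")

async def continuous_rollouts() -> None:
    # Runs forever until cancelled.
    while True:
        await asyncio.sleep(0.01)  # stand-in for one rollout iteration

async def main() -> None:
    rollout_tasks = [asyncio.create_task(continuous_rollouts()) for _ in range(2)]
    training_task = asyncio.create_task(continuous_training(max_steps=3))
    try:
        # Awaiting only the training task lets the finite step limit end the run;
        # gathering the rollouts here as well would block forever.
        await training_task
    finally:
        for t in rollout_tasks:
            t.cancel()
        # Gathering the already-cancelled tasks resolves quickly and
        # swallows the CancelledError noise.
        await asyncio.gather(*rollout_tasks, return_exceptions=True)
        print("shutdown complete")

asyncio.run(main())

Because training_task completes on its own at the step limit, the finally block becomes the single place where rollouts are cancelled, and the trailing gather resolves almost immediately.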
apps/grpo/main.py
Outdated
restart_tracer = True  # Flag to control when to restart tracer

while max_steps is None or training_step < max_steps:
I think we should make cfg.training.steps a required argument, with explicitly setting it to -1 meaning run until interrupted.
Made max_steps:

    max_steps = cfg.trainer.training.steps or -1

and the condition:

    while max_steps < 0 or training_step < max_steps:
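One subtlety of this sentinel, sketched below with illustrative values only: "steps or -1" maps 0 to -1 as well, so an explicit 0 also means "run until interrupted" rather than "run zero steps".

# Sentinel semantics of "steps or -1" (illustrative values, not from the PR):
for steps in (None, 0, 2):
    max_steps = steps or -1  # None -> -1, 0 -> -1, 2 -> 2
    training_step = 0
    while max_steps < 0 or training_step < max_steps:
        training_step += 1
        if max_steps < 0:
            break  # -1 would loop forever; stop the demo after one step
    label = "unbounded" if max_steps < 0 else f"{training_step} steps"
    print(f"steps={steps!r} -> {label}")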
apps/grpo/main.py
Outdated
# graceful await all tasks, ignore cancellation noise
await asyncio.gather(*rollout_tasks, return_exceptions=True)
# Give replicas time to drain and complete in-flight requests
await asyncio.sleep(1)
I think we can/should make this part more graceful. Proposal:
Continuous rollouts takes a shutdown event:
async def continuous_rollouts(shutdown_event: asyncio.Event):
    ...
    while not shutdown_event.is_set():  # no more while True
        ...
    print("Rollout loop got shutdown event, shutting down...")
Then in our finally we can do a two-phase shutdown:
finally:
    print("Shutting down...")
    shutdown_event.set()
    try:
        # give tasks a chance to exit gracefully
        await asyncio.wait_for(
            asyncio.gather(*rollout_tasks, return_exceptions=True),
            timeout=5,
        )
    except asyncio.TimeoutError:
        print("Forcing cancellation...")
        for t in rollout_tasks:
            t.cancel()
        await asyncio.gather(*rollout_tasks, return_exceptions=True)
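For reference, a self-contained, runnable version of this two-phase proposal; the rollout body and timings are stand-ins, not the real loop:

import asyncio

async def continuous_rollouts(shutdown_event: asyncio.Event) -> None:
    while not shutdown_event.is_set():
        await asyncio.sleep(0.01)  # stand-in for one rollout iteration
    print("Rollout loop got shutdown event, shutting down...")

async def main() -> None:
    shutdown_event = asyncio.Event()
    rollout_tasks = [
        asyncio.create_task(continuous_rollouts(shutdown_event)) for _ in range(2)
    ]
    await asyncio.sleep(0.1)  # stand-in for training running to its step limit
    print("Shutting down...")
    shutdown_event.set()
    try:
        # Phase 1: give tasks a chance to exit gracefully.
        await asyncio.wait_for(
            asyncio.gather(*rollout_tasks, return_exceptions=True), timeout=5
        )
    except asyncio.TimeoutError:
        # Phase 2: force-cancel anything still running.
        print("Forcing cancellation...")
        for t in rollout_tasks:
            t.cancel()
        await asyncio.gather(*rollout_tasks, return_exceptions=True)

asyncio.run(main())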
Done
LGTM as long as it's still running correctly!
Previously the GRPO loop ran indefinitely with both training and rollout tasks active. This PR adds a cfg.training.steps limit so training stops after the specified number of steps, then cleanly terminates rollout tasks.

Now logs end with messages like (I tested with cfg.training.steps=2):