-
Notifications
You must be signed in to change notification settings - Fork 17
Training termination for grpo/main with finite step limit #376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
ef013c3
e966d5c
bb711fe
36cc2da
ea9d09c
3b61452
283f305
aeef847
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -433,8 +433,9 @@ async def continuous_rollouts(): | |
async def continuous_training(): | ||
training_step = 0 | ||
restart_tracer = True # Flag to control when to restart tracer | ||
max_steps = cfg.trainer.training.get("steps", None) | ||
|
||
while True: | ||
while max_steps is None or training_step < max_steps: | ||
|
||
# Restart tracer when needed (initial start or after completing a training step) | ||
# Otherwise, we cannot measure time waiting for buffer | ||
if restart_tracer: | ||
|
@@ -471,6 +472,10 @@ async def continuous_training(): | |
# Flush metrics every training step to WandB | ||
await mlogger.flush.call_one(training_step) | ||
|
||
print( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we get the logger from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't really use the logger elsewhere in the main script do we? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The entire main script is using print instead of logger right now. |
||
f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." | ||
) | ||
|
||
num_rollout_threads = cfg.get("rollout_threads", 1) | ||
num_training_threads = cfg.get("training_threads", 1) | ||
print( | ||
|
@@ -482,14 +487,17 @@ async def continuous_training(): | |
training_task = asyncio.create_task(continuous_training()) | ||
|
||
try: | ||
await asyncio.gather(*rollout_tasks, training_task) | ||
await training_task | ||
except KeyboardInterrupt: | ||
print("Training interrupted by user") | ||
finally: | ||
print("Shutting down...") | ||
|
||
for rollout_task in rollout_tasks: | ||
rollout_task.cancel() | ||
# graceful await all tasks, ignore cancellation noise | ||
await asyncio.gather(*rollout_tasks, return_exceptions=True) | ||
training_task.cancel() | ||
Comment on lines
515
to
537
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you explain this change a bit? before we would do gather in the try, and they would be impacted by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think after the first KeyboardInterrupt, all the tasks are canceled and now we just gather on the canceled tasks (which should be fast to resolve) for graceful shutdown. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The previous With this change, rollout termination is handled explicitly in the |
||
finally: | ||
print("Shutting down...") | ||
|
||
# give mlogger time to shutdown backends, otherwise they can stay running. | ||
# TODO (felipemello) find more elegant solution | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you put this at the top of the main() loop? And also make it required please. The default can still be null / None, but this way it's very visible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!