Merged
apps/grpo/main.py (16 changes: 12 additions & 4 deletions)
@@ -433,8 +433,9 @@ async def continuous_rollouts():
     async def continuous_training():
         training_step = 0
         restart_tracer = True  # Flag to control when to restart tracer
+        max_steps = cfg.trainer.training.get("steps", None)
Member:
Can you put this at the top of the main() loop? And also make it required please. The default can still be null / None, but this way it's very visible.

Member Author:
Done!
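
For illustration, a minimal sketch of the suggested pattern, assuming an OmegaConf-style cfg where attribute access raises on a missing key (the config API is not shown in this PR, and the toy values are hypothetical):

    from omegaconf import OmegaConf

    # Toy config standing in for the app's cfg.
    cfg = OmegaConf.create({"trainer": {"training": {"steps": None}}})

    # At the top of main(): read the required key once, very visibly.
    # Attribute access raises if the key is missing; a null/None value
    # still means "no step limit".
    max_steps = cfg.trainer.training.steps
    print(max_steps)  # None

This keeps the knob visible at the top of the function while preserving null as "run forever".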


-        while True:
+        while max_steps is None or training_step < max_steps:
Contributor (@casteryh, Oct 10, 2025):
I think we should make `cfg.training.steps` a required argument. And explicitly setting it to -1 means run until interrupted.

Member Author (@DNXie, Oct 10, 2025):
Made it:

    max_steps = cfg.trainer.training.steps or -1

and the condition:

    while max_steps < 0 or training_step < max_steps:
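
A side note on the `or` idiom (my observation, not raised in the thread): since 0 is falsy in Python, an explicit steps: 0 also falls through to -1 and runs until interrupted. A tiny self-contained check of the sentinel semantics:

    def effective_limit(steps):
        """Mirror of `cfg.trainer.training.steps or -1`."""
        return steps or -1

    assert effective_limit(None) == -1   # null => run until interrupted
    assert effective_limit(-1) == -1     # explicit sentinel
    assert effective_limit(0) == -1      # 0 is falsy, so it also means "forever"
    assert effective_limit(100) == 100   # positive N => stop after N steps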

             # Restart tracer when needed (initial start or after completing a training step)
             # Otherwise, we cannot measure time waiting for buffer
             if restart_tracer:
@@ -471,6 +472,10 @@ async def continuous_training():
             # Flush metrics every training step to WandB
             await mlogger.flush.call_one(training_step)

+        print(
+            f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
+        )
+

Contributor:

Can we get the logger from src/forge/util/logging.py instead?

Member:

We don't really use the logger elsewhere in the main script, do we?

Member Author:

The entire main script is using print instead of logger right now.
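
For reference, a minimal sketch of the suggested swap. Since the interface of src/forge/util/logging.py is not shown in this diff, this uses the stdlib logging module directly; any forge-specific helper name would be an assumption:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    max_steps = 1000  # example value

    # Instead of print(f"Reached training limit ({max_steps} steps). ..."):
    logger.info(
        "Reached training limit (%s steps). Exiting continuous_training loop.",
        max_steps,
    )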

     num_rollout_threads = cfg.get("rollout_threads", 1)
     num_training_threads = cfg.get("training_threads", 1)
     print(
@@ -482,14 +487,17 @@
     training_task = asyncio.create_task(continuous_training())

     try:
-        await asyncio.gather(*rollout_tasks, training_task)
+        await training_task
     except KeyboardInterrupt:
         print("Training interrupted by user")
+    finally:
+        print("Shutting down...")
+
         for rollout_task in rollout_tasks:
             rollout_task.cancel()
+        # graceful await all tasks, ignore cancellation noise
+        await asyncio.gather(*rollout_tasks, return_exceptions=True)
         training_task.cancel()
Comment on lines 515 to 537:

Contributor:
Can you explain this change a bit? Before, we would do the gather in the try, and the tasks would be impacted by KeyboardInterrupt. Now the gather happens after KeyboardInterrupt. Would the user possibly have to send KeyboardInterrupt twice?

Contributor:
> Can you explain this change a bit? Before, we would do the gather in the try, and the tasks would be impacted by KeyboardInterrupt. Now the gather happens after KeyboardInterrupt. Would the user possibly have to send KeyboardInterrupt twice?

I think after the first KeyboardInterrupt, all the tasks are canceled, and now we just gather on the canceled tasks (which should be fast to resolve) for graceful shutdown.

Member Author:
The previous gather(*rollout_tasks, training_task) call blocked on rollouts indefinitely, even after training_task completed (e.g., once max_steps was reached).
Since we now have a finite step limit that cleanly terminates training_task, we shouldn’t continue waiting on rollout tasks.

With this change, rollout termination is handled explicitly in the finally block.
In theory, this doesn't change how KeyboardInterrupt is handled, since an interrupt caught in the except would still flow into the finally.
However, we’ve already seen issues with KeyboardInterrupt handling not working properly (see #360 (2)); I’ll look into that separately in a follow-up PR.
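
To make the described flow concrete, a small self-contained sketch of the cancel-then-gather shutdown pattern (toy sleeps stand in for the actual rollout and training tasks):

    import asyncio

    async def main():
        # Rollouts "run forever"; training finishes on its own (step limit).
        rollout_tasks = [asyncio.create_task(asyncio.sleep(3600)) for _ in range(2)]
        training_task = asyncio.create_task(asyncio.sleep(0.1))
        try:
            await training_task  # returns as soon as training completes
        finally:
            for t in rollout_tasks:
                t.cancel()  # request cancellation of the long-running rollouts
            # return_exceptions=True absorbs the CancelledErrors, so shutdown
            # resolves quickly instead of re-raising.
            await asyncio.gather(*rollout_tasks, return_exceptions=True)
            print("Shutting down...")

    asyncio.run(main())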

-    finally:
-        print("Shutting down...")

         # give mlogger time to shutdown backends, otherwise they can stay running.
         # TODO (felipemello) find more elegant solution