-
Notifications
You must be signed in to change notification settings - Fork 18
Training termination for grpo/main with finite step limit #376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ef013c3
e966d5c
bb711fe
36cc2da
ea9d09c
3b61452
283f305
aeef847
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -349,6 +349,11 @@ async def main(cfg: DictConfig): | |
), | ||
) | ||
|
||
if "steps" not in cfg.trainer.training: | ||
raise ValueError("`cfg.trainer.training.steps` must be defined (can be null).") | ||
|
||
max_steps = cfg.trainer.training.steps | ||
|
||
print("All services initialized successfully!") | ||
|
||
# ---- Core RL loops ---- # | ||
|
@@ -434,7 +439,7 @@ async def continuous_training(): | |
training_step = 0 | ||
restart_tracer = True # Flag to control when to restart tracer | ||
|
||
while True: | ||
while max_steps is None or training_step < max_steps: | ||
|
||
# Restart tracer when needed (initial start or after completing a training step) | ||
# Otherwise, we cannot measure time waiting for buffer | ||
if restart_tracer: | ||
|
@@ -471,6 +476,10 @@ async def continuous_training(): | |
# Flush metrics every training step to WandB | ||
await mlogger.flush.call_one(training_step) | ||
|
||
print( | ||
Review comment (hidden-comment UI omitted): "can we get the logger from here?" — Reply: "We don't really use the logger elsewhere in the main script, do we?" — Reply: "The entire main script is using print instead of logger right now." |
||
f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." | ||
) | ||
|
||
num_rollout_threads = cfg.get("rollout_threads", 1) | ||
num_training_threads = cfg.get("training_threads", 1) | ||
print( | ||
|
@@ -482,14 +491,18 @@ async def continuous_training(): | |
training_task = asyncio.create_task(continuous_training()) | ||
|
||
try: | ||
await asyncio.gather(*rollout_tasks, training_task) | ||
await training_task | ||
except KeyboardInterrupt: | ||
print("Training interrupted by user") | ||
finally: | ||
print("Shutting down...") | ||
for rollout_task in rollout_tasks: | ||
rollout_task.cancel() | ||
# gracefully await all tasks, ignoring cancellation noise | ||
await asyncio.gather(*rollout_tasks, return_exceptions=True) | ||
# Give replicas time to drain and complete in-flight requests | ||
await asyncio.sleep(1) | ||
|
||
training_task.cancel() | ||
finally: | ||
print("Shutting down...") | ||
|
||
# give mlogger time to shutdown backends, otherwise they can stay running. | ||
# TODO (felipemello) find more elegant solution | ||
|
Uh oh!
There was an error while loading. Please reload this page.