Training termination for grpo/main with finite step limit #376
Changes from 6 commits
ef013c3
e966d5c
bb711fe
36cc2da
ea9d09c
3b61452
283f305
aeef847
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -349,6 +349,9 @@ async def main(cfg: DictConfig): | |
| ), | ||
| ) | ||
|  | ||
| # Set max_steps to the configured value, or -1 if not specified or Null | ||
| max_steps = cfg.trainer.training.steps or -1 | ||
|  | ||
| print("All services initialized successfully!") | ||
|  | ||
| # ---- Core RL loops ---- # | ||
|  | @@ -434,7 +437,7 @@ async def continuous_training(): | |
| training_step = 0 | ||
| restart_tracer = True # Flag to control when to restart tracer | ||
|  | ||
| while True: | ||
| while max_steps == -1 or training_step < max_steps: | ||
| # Restart tracer when needed (initial start or after completing a training step) | ||
| # Otherwise, we cannot measure time waiting for buffer | ||
| if restart_tracer: | ||
|  | @@ -471,6 +474,10 @@ async def continuous_training(): | |
| # Flush metrics every training step to WandB | ||
| await mlogger.flush.call_one(training_step) | ||
|  | ||
| print( | ||
| f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." | ||
| ) | ||
|  | ||
| num_rollout_threads = cfg.get("rollout_threads", 1) | ||
| num_training_threads = cfg.get("training_threads", 1) | ||
| print( | ||
|  | @@ -482,14 +489,18 @@ async def continuous_training(): | |
| training_task = asyncio.create_task(continuous_training()) | ||
|  | ||
| try: | ||
| await asyncio.gather(*rollout_tasks, training_task) | ||
| await training_task | ||
| except KeyboardInterrupt: | ||
| print("Training interrupted by user") | ||
| finally: | ||
| print("Shutting down...") | ||
| for rollout_task in rollout_tasks: | ||
| rollout_task.cancel() | ||
| # graceful await all tasks, ignore cancellation noise | ||
| await asyncio.gather(*rollout_tasks, return_exceptions=True) | ||
| # Give replicas time to drain and complete in-flight requests | ||
| await asyncio.sleep(1) | ||
|          | ||
| training_task.cancel() | ||
| finally: | ||
| print("Shutting down...") | ||
|  | ||
| # give mlogger time to shutdown backends, otherwise they can stay running. | ||
| # TODO (felipemello) find more elegant solution | ||
|  | ||
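
The termination pattern in the diff above can be sketched in isolation as follows. This is a minimal, self-contained illustration under stated assumptions, not the project's code: `resolve_max_steps`, `run`, `state`, and the stand-in `continuous_rollouts` loop are hypothetical names introduced here, and the real training step is replaced by a bare `asyncio.sleep(0)`. Only `max_steps`, `training_step`, `rollout_tasks`, `training_task`, and `continuous_training` mirror names from the diff. Note that, as in the diff, `value or -1` treats a configured `0` the same as an unset value.

```python
import asyncio


def resolve_max_steps(configured):
    """Hypothetical helper: return the step limit, or -1 (no limit)
    when the configured value is None or 0, mirroring `steps or -1`."""
    return configured or -1


async def run(max_steps, num_rollout_tasks=2):
    # Shared counters standing in for the real trainer state (assumption).
    state = {"training_step": 0, "rollouts": 0}

    async def continuous_rollouts():
        # Stand-in rollout loop: runs until it is cancelled.
        while True:
            state["rollouts"] += 1
            await asyncio.sleep(0)

    async def continuous_training():
        # Same loop condition as the diff: -1 means "run forever".
        while max_steps == -1 or state["training_step"] < max_steps:
            await asyncio.sleep(0)  # stand-in for one training step
            state["training_step"] += 1

    rollout_tasks = [
        asyncio.create_task(continuous_rollouts())
        for _ in range(num_rollout_tasks)
    ]
    training_task = asyncio.create_task(continuous_training())

    try:
        # Wait only on training; the rollout loops never finish on their own.
        await training_task
    finally:
        for rollout_task in rollout_tasks:
            rollout_task.cancel()
        # Await the cancelled tasks and swallow their CancelledError.
        await asyncio.gather(*rollout_tasks, return_exceptions=True)

    all_cancelled = all(task.cancelled() for task in rollout_tasks)
    return state["training_step"], all_cancelled
```

Calling `asyncio.run(run(resolve_max_steps(3)))` runs three stand-in training steps and then cancels the rollout tasks. Awaiting `training_task` alone, rather than gathering it with the rollout tasks, is what lets a finite step limit end the program, since the rollout loops have no exit condition of their own.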
Can we get the logger from `src/forge/util/logging.py` instead?
We don't really use the logger elsewhere in the main script, do we?
The entire main script is using print instead of logger right now.