diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 18b82c3e9..6439ead85 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -359,8 +359,11 @@ async def main(cfg: DictConfig):
         ),
     )
 
-    print("All services initialized successfully!")
+    # Set max_steps to the configured value, or -1 if not specified or Null
+    max_steps = cfg.trainer.training.steps or -1
 
+    print("All services initialized successfully!")
+    shutdown_event = asyncio.Event()
     # In the HostMesh v1 case, we spawn a torchstore storage volume
     # per trainer process.
     # We initialize after service initialization because torchstore currently
@@ -381,7 +384,7 @@ async def continuous_rollouts():
         rollout_count = 0
         pad_id = await dataloader.pad_token.call_one()
 
-        while True:
+        while not shutdown_event.is_set():
             t = Tracer("main_perf/continuous_rollouts")
             t.start()
             sample = await dataloader.sample.call_one()
@@ -460,7 +463,7 @@ async def continuous_training():
         training_step = 0
         restart_tracer = True  # Flag to control when to restart tracer
 
-        while True:
+        while max_steps == -1 or training_step < max_steps:
            # Restart tracer when needed (initial start or after completing a training step)
            # Otherwise, we cannot measure time waiting for buffer
            if restart_tracer:
@@ -497,6 +500,10 @@
            # Flush metrics every training step to WandB
            await mlogger.flush.call_one(training_step)
 
+        print(
+            f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
+        )
+
    num_rollout_threads = cfg.get("rollout_threads", 1)
    num_training_threads = cfg.get("training_threads", 1)
    print(
@@ -508,14 +515,26 @@
    training_task = asyncio.create_task(continuous_training())
 
    try:
-        await asyncio.gather(*rollout_tasks, training_task)
+        await training_task
    except KeyboardInterrupt:
        print("Training interrupted by user")
-        for rollout_task in rollout_tasks:
-            rollout_task.cancel()
-        training_task.cancel()
    finally:
        print("Shutting down...")
+        shutdown_event.set()
+
+        try:
+            # Give rollouts up to 5s to finish naturally
+            await asyncio.wait_for(
+                asyncio.gather(*rollout_tasks, return_exceptions=True),
+                timeout=5,
+            )
+        except asyncio.TimeoutError:
+            print("Timeout waiting for rollouts; forcing cancellation...")
+            for t in rollout_tasks:
+                t.cancel()
+            await asyncio.gather(*rollout_tasks, return_exceptions=True)
+
+        training_task.cancel()
 
        # give mlogger time to shutdown backends, otherwise they can stay running.
        # TODO (felipemello) find more elegant solution
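The hunks above replace unconditional `while True` loops and blind task cancellation with cooperative shutdown: rollout loops watch an `asyncio.Event`, the driver awaits only the bounded training task, and on exit it grants rollouts a grace period before force-cancelling. Below is a minimal standalone sketch of that pattern, assuming nothing beyond stdlib asyncio; the `worker`/`driver` names, the sleep placeholders, and the 5-second grace period are illustrative stand-ins, not code from this repo.

```python
import asyncio


async def worker(name: str, shutdown_event: asyncio.Event) -> None:
    # Stand-in for continuous_rollouts(): loop until the driver signals shutdown.
    while not shutdown_event.is_set():
        await asyncio.sleep(0.1)  # placeholder for one rollout iteration
    print(f"{name} exited cleanly")


async def driver() -> None:
    shutdown_event = asyncio.Event()
    rollout_tasks = [
        asyncio.create_task(worker(f"rollout-{i}", shutdown_event)) for i in range(2)
    ]

    await asyncio.sleep(0.5)  # placeholder for awaiting the bounded training task

    # Signal shutdown, allow a grace period, then force-cancel any stragglers.
    shutdown_event.set()
    try:
        await asyncio.wait_for(
            asyncio.gather(*rollout_tasks, return_exceptions=True), timeout=5
        )
    except asyncio.TimeoutError:
        for t in rollout_tasks:
            t.cancel()
        await asyncio.gather(*rollout_tasks, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(driver())
```

The event-based exit lets each rollout finish its current iteration instead of being interrupted mid-step, which is why the diff only cancels tasks after the timeout expires.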