33 changes: 26 additions & 7 deletions apps/grpo/main.py
@@ -359,8 +359,11 @@ async def main(cfg: DictConfig):
),
)

print("All services initialized successfully!")
# Set max_steps to the configured value, or -1 if not specified or Null
max_steps = cfg.trainer.training.steps or -1

print("All services initialized successfully!")
shutdown_event = asyncio.Event()
# In the HostMesh v1 case, we spawn a torchstore storage volume
# per trainer process.
# We initialize after service initialization because torchstore currently
@@ -381,7 +384,7 @@ async def continuous_rollouts():
async def continuous_rollouts():
rollout_count = 0
pad_id = await dataloader.pad_token.call_one()
while True:
while not shutdown_event.is_set():
t = Tracer("main_perf/continuous_rollouts")
t.start()
sample = await dataloader.sample.call_one()
@@ -460,7 +463,7 @@ async def continuous_training():
training_step = 0
restart_tracer = True # Flag to control when to restart tracer

while True:
while max_steps == -1 or training_step < max_steps:
# Restart tracer when needed (initial start or after completing a training step)
# Otherwise, we cannot measure time waiting for buffer
if restart_tracer:
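
A side note on the max_steps idiom above: cfg.trainer.training.steps or -1 falls back to -1 for any falsy value, so a null or missing setting (and also an explicit 0) yields -1, which the training loop then treats as "no limit". A standalone illustration of the two expressions, with simplified, hypothetical names:

# Standalone illustration of the max_steps handling above (simplified names).
def resolve_max_steps(configured_steps):
    # None (YAML null) and 0 are both falsy, so either falls back to -1.
    return configured_steps or -1

def should_continue(max_steps, training_step):
    # Mirrors: while max_steps == -1 or training_step < max_steps
    return max_steps == -1 or training_step < max_steps

assert resolve_max_steps(None) == -1
assert resolve_max_steps(0) == -1  # 0 means "unlimited" here, not "stop immediately"
assert resolve_max_steps(100) == 100
assert should_continue(-1, 10_000)  # -1 never stops the loop
assert not should_continue(100, 100)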
@@ -497,6 +500,10 @@ async def continuous_training():
# Flush metrics every training step to WandB
await mlogger.flush.call_one(training_step)

print(
f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
)

Review thread on this print call:

Contributor: can we get the logger from src/forge/util/logging.py instead?

Member: We don't really use the logger elsewhere in the main script do we?

Member Author: The entire main script is using print instead of logger right now.
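
For reference, a minimal sketch of what the suggested swap might look like, assuming src/forge/util/logging.py exposes a get_logger helper (an assumption for illustration; the actual API is not shown in this diff):

# Hypothetical sketch only: the get_logger helper below is assumed,
# not confirmed by this PR.
from forge.util.logging import get_logger  # assumed API

logger = get_logger(__name__)

logger.info(
    "Reached training limit (%d steps). Exiting continuous_training loop.",
    max_steps,
)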

num_rollout_threads = cfg.get("rollout_threads", 1)
num_training_threads = cfg.get("training_threads", 1)
print(
@@ -508,14 +515,26 @@ async def continuous_training():
training_task = asyncio.create_task(continuous_training())

try:
await asyncio.gather(*rollout_tasks, training_task)
await training_task
except KeyboardInterrupt:
print("Training interrupted by user")
for rollout_task in rollout_tasks:
rollout_task.cancel()
training_task.cancel()
finally:
print("Shutting down...")
shutdown_event.set()

try:
# Give rollouts up to 5s to finish naturally
await asyncio.wait_for(
asyncio.gather(*rollout_tasks, return_exceptions=True),
timeout=5,
)
except asyncio.TimeoutError:
print("Timeout waiting for rollouts; forcing cancellation...")
for t in rollout_tasks:
t.cancel()
await asyncio.gather(*rollout_tasks, return_exceptions=True)

training_task.cancel()

# give mlogger time to shutdown backends, otherwise they can stay running.
# TODO (felipemello) find more elegant solution
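
The shutdown sequence above follows a common asyncio pattern: set a shared event so workers can exit cooperatively, wait a bounded time for them to finish, then force-cancel stragglers and reap the cancellations. A self-contained sketch of the same pattern, with illustrative names that are not from this PR:

import asyncio

async def worker(shutdown_event: asyncio.Event) -> None:
    # Check the event on every iteration so the loop can exit cooperatively,
    # mirroring `while not shutdown_event.is_set()` in continuous_rollouts.
    while not shutdown_event.is_set():
        await asyncio.sleep(0.1)  # stand-in for one unit of rollout work

async def main() -> None:
    shutdown_event = asyncio.Event()
    tasks = [asyncio.create_task(worker(shutdown_event)) for _ in range(4)]

    shutdown_event.set()  # begin shutdown
    try:
        # Give workers a bounded window to finish naturally.
        await asyncio.wait_for(
            asyncio.gather(*tasks, return_exceptions=True), timeout=5
        )
    except asyncio.TimeoutError:
        # Force-cancel anything still running, then await the cancellations.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

asyncio.run(main())

return_exceptions=True matters in both gather calls: it keeps a single worker's exception (or CancelledError) from aborting cleanup of the remaining tasks.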