Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,9 @@ async def main(cfg: DictConfig):
),
)

# Set max_steps to the configured value, or -1 if not specified or Null
max_steps = cfg.trainer.training.steps or -1

print("All services initialized successfully!")

# ---- Core RL loops ---- #
Expand Down Expand Up @@ -434,7 +437,7 @@ async def continuous_training():
training_step = 0
restart_tracer = True # Flag to control when to restart tracer

while True:
while max_steps == -1 or training_step < max_steps:
# Restart tracer when needed (initial start or after completing a training step)
# Otherwise, we cannot measure time waiting for buffer
if restart_tracer:
Expand Down Expand Up @@ -471,6 +474,10 @@ async def continuous_training():
# Flush metrics every training step to WandB
await mlogger.flush.call_one(training_step)

print(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we get the logger from src/forge/util/logging.py instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really use the logger elsewhere in the main script do we?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The entire main script is using print instead of logger right now.

f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
)

num_rollout_threads = cfg.get("rollout_threads", 1)
num_training_threads = cfg.get("training_threads", 1)
print(
Expand All @@ -482,14 +489,18 @@ async def continuous_training():
training_task = asyncio.create_task(continuous_training())

try:
await asyncio.gather(*rollout_tasks, training_task)
await training_task
except KeyboardInterrupt:
print("Training interrupted by user")
finally:
print("Shutting down...")
for rollout_task in rollout_tasks:
rollout_task.cancel()
# graceful await all tasks, ignore cancellation noise
await asyncio.gather(*rollout_tasks, return_exceptions=True)
# Give replicas time to drain and complete in-flight requests
await asyncio.sleep(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can/should make this part more graceful. Proposal:

Continuous rollouts takes a shutdown event:

async def continuous_rollouts(shutdown_evnt: asyncio.Event):
    ...
    while not shutdown_event.is_set(): # no more while True
        ...

    print("Rollout loop got shutdown event, shutting down...")

then in our finally we can do a 2-phased shutdown:

finally:
    print("Shutting down...")
    shutdown_event.set()

    try:
        # give tasks a chance to exit gracefully
        await asyncio.wait_for(
            asyncio.gather(*rollout_tasks, return_exceptions=True),
            timeout=5
        )
    except asyncio.TimeoutError:
        print("Forcing cancellation...")
        for t in rollout_tasks:
            t.cancel()
        await asyncio.gather(*rollout_tasks, return_exceptions=True)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

training_task.cancel()
finally:
print("Shutting down...")

# give mlogger time to shutdown backends, otherwise they can stay running.
# TODO (felipemello) find more elegant solution
Expand Down
5 changes: 5 additions & 0 deletions src/forge/controller/service/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ async def _process_single_request(self, request: ServiceRequest) -> bool:
# can be healthy but the request failed.
self.mark_failed()
success = False
except asyncio.InvalidStateError:
# Future was already cancelled — safe to ignore during shutdown
logger.warning(f"Got invalid state error on replica {self.idx}.")
self.mark_failed()
success = False
except Exception as e:
logger.debug(f"Got unexpected error on replica {self.idx}. Error:\n{e}")
self.mark_failed()
Expand Down
Loading