32 changes: 28 additions & 4 deletions agentlightning/runner/agent.py
@@ -33,6 +33,7 @@
from agentlightning.litagent import LitAgent
from agentlightning.reward import emit_reward, find_final_reward
from agentlightning.store.base import LightningStore
from agentlightning.store.client_server import ServerShutdownError
from agentlightning.tracer.agentops import AgentOpsTracer
from agentlightning.tracer.base import Tracer
from agentlightning.types import (
@@ -289,7 +290,14 @@ async def _post_process_rollout_result(
# This will NOT emit another span to the tracer
reward_span = emit_reward(raw_result, propagate=False)
# We add it to the store manually
await store.add_otel_span(rollout.rollout_id, rollout.attempt.attempt_id, reward_span)
try:
await store.add_otel_span(rollout.rollout_id, rollout.attempt.attempt_id, reward_span)
except ServerShutdownError:
# Server is shutting down - handle gracefully without traceback
logger.debug(
Contributor comment: I think this isn't the only place with such an issue. We'd better handle it in the store client. (A sketch of this idea follows this file's diff.)
f"{self._log_prefix(rollout.rollout_id)} Server is shutting down. "
"Skipping add_otel_span for reward span."
)
trace_spans.append(reward_span)

if isinstance(raw_result, list):
@@ -304,9 +312,16 @@
self._tracer, AgentOpsTracer
): # TODO: this should be replaced with general OpenTelemetry tracer in next version
for span in raw_result:
await store.add_otel_span(
rollout.rollout_id, rollout.attempt.attempt_id, cast(ReadableSpan, span)
)
try:
await store.add_otel_span(
rollout.rollout_id, rollout.attempt.attempt_id, cast(ReadableSpan, span)
)
except ServerShutdownError:
# Server is shutting down - handle gracefully without traceback
logger.debug(
f"{self._log_prefix(rollout.rollout_id)} Server is shutting down. "
f"Skipping add_otel_span for span: {span.name}"
)
else:
logger.warning(
f"{self._log_prefix(rollout.rollout_id)} Tracer is already an OpenTelemetry tracer. "
@@ -528,6 +543,9 @@ async def _step_impl(self, next_rollout: AttemptedRollout, raise_on_exception: b
await store.update_attempt(rollout_id, next_rollout.attempt.attempt_id, status="failed")
else:
await store.update_attempt(rollout_id, next_rollout.attempt.attempt_id, status="succeeded")
except ServerShutdownError:
# Server is shutting down - handle gracefully without traceback
logger.debug(f"{self._log_prefix(rollout_id)} Server is shutting down. " "Skipping update_attempt.")
except Exception:
logger.exception(
f"{self._log_prefix(rollout_id)} Exception during update_attempt. Giving up the update."
@@ -582,6 +600,12 @@ async def iter(self, *, event: Optional[ExecutionEvent] = None) -> None:
await store.update_attempt(
next_rollout.rollout_id, next_rollout.attempt.attempt_id, worker_id=self.get_worker_id()
)
except ServerShutdownError:
# Server is shutting down - handle gracefully without traceback
logger.debug(
f"{self._log_prefix()} Server is shutting down. " "Skipping update_attempt for rollout claim."
)
continue
except Exception:
# This exception could happen if the rollout is dequeued and the other end died for some reason
logger.exception(f"{self._log_prefix()} Exception during update_attempt, giving up the rollout.")
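To illustrate the reviewer's suggestion above (handling the shutdown case once in the store client instead of wrapping every add_otel_span and update_attempt call site in the runner), here is a minimal sketch. The decorator name suppress_on_shutdown and the simplified client stub are hypothetical and not part of this PR.

# Hypothetical sketch: centralize ServerShutdownError handling in the store client,
# so call sites in the runner would not need their own try/except blocks.
import functools
import logging
from typing import Any, Awaitable, Callable, Optional, TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T")


class ServerShutdownError(Exception):
    """Raised when the server is shutting down and requests cannot be completed."""


def suppress_on_shutdown(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[Optional[T]]]:
    """Swallow ServerShutdownError, log at debug level, and return None instead of raising."""

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Optional[T]:
        try:
            return await func(*args, **kwargs)
        except ServerShutdownError:
            logger.debug("Server is shutting down. Skipping %s.", func.__name__)
            return None

    return wrapper


class StoreClientSketch:
    """Illustrative stand-in for the HTTP store client; not the real class."""

    @suppress_on_shutdown
    async def add_otel_span(self, rollout_id: str, attempt_id: str, span: Any) -> None:
        # Would issue the HTTP request (e.g. via _request_json), which may raise
        # ServerShutdownError once the server fails its health checks.
        ...

With this approach the runner could call store.add_otel_span(...) unconditionally, and the shutdown case would be absorbed inside the client.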
30 changes: 28 additions & 2 deletions agentlightning/store/client_server.py
@@ -75,6 +75,15 @@
T_model = TypeVar("T_model", bound=BaseModel)


class ServerShutdownError(Exception):
"""Raised when the server is shutting down and requests cannot be completed.

This exception is raised instead of ServerDisconnectedError when we detect
that the server is permanently unavailable (e.g., during graceful shutdown).
Callers should handle this gracefully without dumping full tracebacks.
"""


class RolloutRequest(BaseModel):
input: TaskInput
mode: Optional[Literal["train", "val", "test"]] = None
@@ -1238,6 +1247,9 @@ def __init__(
self._dequeue_was_successful: bool = False
self._dequeue_first_unsuccessful: bool = True

# Track server shutdown state to handle errors gracefully
self._server_shutting_down: bool = False
Contributor comment: How about having a _server_online shared flag so we can sunset hacky flags like _dequeue_first_unsuccessful? (A sketch of this idea follows this file's diff.)

@property
def capabilities(self) -> LightningStoreCapabilities:
"""Return the capabilities of the store."""
@@ -1287,6 +1299,7 @@ def __setstate__(self, state: Dict[str, Any]):
self._connection_timeout = state["_connection_timeout"]
self._dequeue_was_successful = False
self._dequeue_first_unsuccessful = True
self._server_shutting_down = False

async def _get_session(self) -> aiohttp.ClientSession:
# In the proxy process, FastAPI middleware calls
@@ -1324,6 +1337,7 @@ async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool:
"""
Probe the server's /health until it responds 200 or retries are exhausted.
Returns True if healthy, False otherwise.
When this returns False, it indicates the server is shutting down or permanently unavailable.
"""
if not self._health_retry_delays:
client_logger.info("No health retry delays configured; skipping health checks.")
@@ -1342,9 +1356,12 @@
client_logger.warning(f"Server is not healthy yet. Retrying in {delay} seconds.")
if delay > 0.0:
await asyncio.sleep(delay)
client_logger.error(
f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts"
client_logger.warning(
f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts. "
"Server appears to be shutting down."
)
# Mark server as shutting down to handle subsequent errors gracefully
self._server_shutting_down = True
return False

async def _request_json(
Expand Down Expand Up @@ -1405,6 +1422,15 @@ async def _request_json(
last_exc = net_exc
client_logger.info(f"Network/session issue will be retried. Retrying the request {method}: {path}")
if not await self._wait_until_healthy(session):
# Server is shutting down - handle ServerDisconnectedError gracefully
if isinstance(net_exc, aiohttp.ServerDisconnectedError) and self._server_shutting_down:
client_logger.debug(
f"Server is shutting down. Suppressing ServerDisconnectedError for {method}: {path}"
)
# Raise a specific exception that callers can catch and handle gracefully
raise ServerShutdownError(
f"Server is shutting down. Request {method}: {path} cannot be completed."
) from net_exc
break # server is not healthy, do not retry

# exhausted retries
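The second reviewer comment suggests replacing the ad-hoc flags (_server_shutting_down, _dequeue_first_unsuccessful) with a single shared _server_online state. A minimal sketch of that idea follows; the class and method names (ServerStatus, mark_online, mark_offline) are illustrative only and not taken from the PR.

# Hypothetical sketch: one _server_online flag owned by the client, flipped by the
# health-check path and consulted wherever the code currently checks
# _server_shutting_down or _dequeue_first_unsuccessful.
import logging

client_logger = logging.getLogger(__name__)


class ServerStatus:
    """Tracks whether the store server is believed to be reachable."""

    def __init__(self) -> None:
        self._server_online: bool = True

    @property
    def online(self) -> bool:
        return self._server_online

    def mark_online(self) -> None:
        # Called after any successful request or a 200 response from /health.
        if not self._server_online:
            client_logger.info("Server is reachable again.")
        self._server_online = True

    def mark_offline(self) -> None:
        # Called when the health retry delays are exhausted; subsequent errors can
        # then be logged at debug level instead of raised with a full traceback.
        if self._server_online:
            client_logger.warning("Server appears to be shutting down or offline.")
        self._server_online = False

Under this scheme, _wait_until_healthy would call mark_offline() when it exhausts its retries, and _request_json could consult status.online to decide between raising ServerShutdownError and retrying.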