Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions agentlightning/execution/client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,21 @@ def __init__(
)
self.allowed_exit_codes = tuple(allowed_exit_codes)

# This flag is set to True after server launches and False before server stops
# Clients check this flag when requests fail - if server is not online, silently ignore errors
# Be mindful of performance: all processes need to synchronously read this flag
ctx = multiprocessing.get_context()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add unit-tests for this feature!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which file do you want the unit tests? Or should I create a new one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in tests/execution/test_client_server.py

self._server_online = ctx.Value("b", False) # 'b' = signed char, False = 0

async def _execute_algorithm(
self, algorithm: AlgorithmBundle, store: LightningStore, stop_evt: ExecutionEvent
) -> None:
wrapper_store: LightningStore | None = None
if self.managed_store:
logger.info("Starting LightningStore server on %s:%s", self.server_host, self.server_port)
wrapper_store = LightningStoreServer(store, host=self.server_host, port=self.server_port)
wrapper_store = LightningStoreServer(
store, host=self.server_host, port=self.server_port, server_online_flag=self._server_online
)
server_started = False
else:
wrapper_store = store
Expand Down Expand Up @@ -173,7 +181,9 @@ async def _execute_runner(
) -> None:
if self.managed_store:
# If managed, we actually do not use the provided store
client_store = LightningStoreClient(f"http://{self.server_host}:{self.server_port}")
client_store = LightningStoreClient(
f"http://{self.server_host}:{self.server_port}", server_online_flag=self._server_online
)
else:
client_store = store
try:
Expand Down
162 changes: 128 additions & 34 deletions agentlightning/store/client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@
T_model = TypeVar("T_model", bound=BaseModel)


class ServerShutdownError(Exception):
    """Signals that the store server is going away, so the request cannot succeed.

    Clients raise this in place of ``ServerDisconnectedError`` once they have
    determined the server is permanently gone (for example, a graceful shutdown
    is in progress). Code catching it should treat the failure as expected and
    avoid emitting full tracebacks.
    """

class RolloutRequest(BaseModel):
input: TaskInput
mode: Optional[Literal["train", "val", "test"]] = None
Expand Down Expand Up @@ -252,6 +261,8 @@ def __init__(
launch_mode: LaunchMode = "thread",
launcher_args: PythonServerLauncherArgs | None = None,
n_workers: int = 1,
prometheus: bool = False,
server_online_flag: Any = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be Any and it should be documented in docstring

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed Any and documented server_online_flag in the docstring. But with object | None typing, we get linter errors accessing .get_lock() and .value (from multiprocessing.synchronize.Synchronized). Should we use a Protocol import with TYPE_CHECKING or just use type ignores?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe you should document it as whatever the type of ctx.Value("b", False) is. Use copilot/cc/codex to help you if you are not sure.

tracker: MetricsBackend | None = None,
):
super().__init__()
Expand Down Expand Up @@ -301,12 +312,15 @@ def __init__(
# LightningStoreServer holds a plain Python object (self.store) in one process
# (the process that runs uvicorn/FastAPI).
# When you multiprocessing.Process(...) and call methods on a different LightningStore instance
# (or on a copy inherited via fork), youre mutating another processs memory, not the servers memory.
# (or on a copy inherited via fork), you're mutating another process's memory, not the server's memory.
# So we need to track the owner process (whoever creates the server),
# and only mutate the store in that process.
self._owner_pid = os.getpid()
self._client: Optional[LightningStoreClient] = None

# Set to True after server launches, False before server stops
self._server_online_flag = server_online_flag

@property
def capabilities(self) -> LightningStoreCapabilities:
"""Return the capabilities of the store."""
Expand Down Expand Up @@ -416,6 +430,11 @@ async def start(self):
end_time = time.time()
server_logger.info(f"Lightning store server started in {end_time - start_time:.2f} seconds")

# Set server online flag to True after server has launched
if self._server_online_flag is not None:
with self._server_online_flag.get_lock():
self._server_online_flag.value = True

async def run_forever(self):
"""Runs the FastAPI server indefinitely."""
server_logger.info(
Expand All @@ -428,6 +447,11 @@ async def stop(self):

You need to call this method in the same process as the server was created in.
"""
# Set server online flag to False before server stops
if self._server_online_flag is not None:
with self._server_online_flag.get_lock():
self._server_online_flag.value = False

server_logger.info("Stopping the lightning store server...")
await self.server_launcher.stop()
server_logger.info("Lightning store server stopped.")
Expand Down Expand Up @@ -1355,6 +1379,7 @@ def __init__(
health_retry_delays: Sequence[float] = (0.1, 0.2, 0.5),
request_timeout: float = 30.0,
connection_timeout: float = 5.0,
server_online_flag: Any = None,
):
self.server_address_root = server_address.rstrip("/")
self.server_address = self.server_address_root + API_V1_AGL_PREFIX
Expand All @@ -1371,7 +1396,9 @@ def __init__(

# Store whether the dequeue was successful in history
self._dequeue_was_successful: bool = False
self._dequeue_first_unsuccessful: bool = True

# When requests fail, check this flag - if server is not online, silently ignore errors
self._server_online_flag = server_online_flag

@property
def capabilities(self) -> LightningStoreCapabilities:
Expand Down Expand Up @@ -1421,7 +1448,6 @@ def __setstate__(self, state: Dict[str, Any]):
self._request_timeout = state["_request_timeout"]
self._connection_timeout = state["_connection_timeout"]
self._dequeue_was_successful = False
self._dequeue_first_unsuccessful = True

async def _get_session(self) -> aiohttp.ClientSession:
# In the proxy process, FastAPI middleware calls
Expand Down Expand Up @@ -1459,6 +1485,7 @@ async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool:
"""
Probe the server's /health until it responds 200 or retries are exhausted.
Returns True if healthy, False otherwise.
When this returns False, it indicates the server is shutting down or permanently unavailable.
"""
if not self._health_retry_delays:
client_logger.info("No health retry delays configured; skipping health checks.")
Expand All @@ -1477,8 +1504,9 @@ async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool:
client_logger.warning(f"Server is not healthy yet. Retrying in {delay} seconds.")
if delay > 0.0:
await asyncio.sleep(delay)
client_logger.error(
f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts"
client_logger.warning(
f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts. "
"Server appears to be shutting down."
)
return False

Expand Down Expand Up @@ -1540,10 +1568,54 @@ async def _request_json(
last_exc = net_exc
client_logger.info(f"Network/session issue: {net_exc} - will retry the request {method}: {path}")
if not await self._wait_until_healthy(session):
# Check shared flag - if server is not online, silently ignore error
if self._server_online_flag is not None:
with self._server_online_flag.get_lock():
is_online = bool(self._server_online_flag.value)
if not is_online:
client_logger.debug(
f"Server is not online (shared flag). Silently ignoring {type(net_exc).__name__} for {method}: {path}"
)
# Silently ignore - return None to indicate failure was expected
return None
break # server is not healthy, do not retry
except asyncio.CancelledError as cancel_exc:
# Cancellation can occur during async operations, especially during shutdown
client_logger.debug(f"Request cancelled: {method}: {path}", exc_info=True)
# Check shared flag - if server is not online, silently ignore error
if self._server_online_flag is not None:
with self._server_online_flag.get_lock():
is_online = bool(self._server_online_flag.value)
if not is_online:
client_logger.debug(
f"Server is not online (shared flag). Silently ignoring CancelledError for {method}: {path}"
)
# Silently ignore - return None to indicate failure was expected
return None
# If flag not available or server is online, re-raise cancellation
raise cancel_exc

# exhausted retries
assert last_exc is not None
# Before raising, check shared flag - if server is not online, silently ignore error
if isinstance(
last_exc,
(
aiohttp.ServerDisconnectedError,
aiohttp.ClientConnectorError,
aiohttp.ClientOSError,
asyncio.TimeoutError,
),
):
if self._server_online_flag is not None:
with self._server_online_flag.get_lock():
is_online = bool(self._server_online_flag.value)
if not is_online:
client_logger.debug(
f"Server is not online (shared flag). Silently ignoring {type(last_exc).__name__} for {method}: {path}"
)
# Silently ignore - return None to indicate failure was expected
return None
raise last_exc

async def close(self):
Expand Down Expand Up @@ -1649,10 +1721,18 @@ async def _dequeue_batch(
self._dequeue_was_successful = True
return [AttemptedRollout.model_validate(item) for item in data]
except Exception as e:
# Check shared flag - if server is not online, silently ignore error
if self._server_online_flag is not None:
with self._server_online_flag.get_lock():
is_online = bool(self._server_online_flag.value)
if not is_online:
client_logger.debug(
f"Server is not online (shared flag). Silently ignoring dequeue_rollout failure: {e}"
)
return None
# Log warning if server was online and dequeue was successful before (transition from online to offline)
if self._dequeue_was_successful:
if self._dequeue_first_unsuccessful:
client_logger.warning(f"dequeue_rollout failed with exception: {e}")
self._dequeue_first_unsuccessful = False
client_logger.warning(f"dequeue_rollout failed with exception: {e}")
client_logger.debug("dequeue_rollout failed with exception. Details:", exc_info=True)
# Else ignore the exception because the server is not ready yet
return []
Expand Down Expand Up @@ -1916,16 +1996,23 @@ async def add_otel_span(
readable_span: ReadableSpan,
sequence_id: int | None = None,
) -> Optional[Span]:
# unchanged logic, now benefits from retries inside add_span/get_next_span_sequence_id
if sequence_id is None:
sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id)
span = Span.from_opentelemetry(
readable_span,
rollout_id=rollout_id,
attempt_id=attempt_id,
sequence_id=sequence_id,
)
return await self.add_span(span)
try:
# unchanged logic, now benefits from retries inside add_span/get_next_span_sequence_id
if sequence_id is None:
sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id)
span = Span.from_opentelemetry(
readable_span,
rollout_id=rollout_id,
attempt_id=attempt_id,
sequence_id=sequence_id,
)
return await self.add_span(span)
except (ServerShutdownError, asyncio.CancelledError):
# Server is shutting down or request was cancelled - handle gracefully without traceback
client_logger.debug(
f"Server is shutting down or request cancelled. Skipping add_otel_span for rollout {rollout_id}, attempt {attempt_id}."
)
return None

async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
"""Wait for rollouts to complete.
Expand Down Expand Up @@ -2030,22 +2117,29 @@ async def update_attempt(
last_heartbeat_time: float | Unset = UNSET,
metadata: Optional[Dict[str, Any]] | Unset = UNSET,
) -> Attempt:
payload: Dict[str, Any] = {}
if not isinstance(status, Unset):
payload["status"] = status
if not isinstance(worker_id, Unset):
payload["worker_id"] = worker_id
if not isinstance(last_heartbeat_time, Unset):
payload["last_heartbeat_time"] = last_heartbeat_time
if not isinstance(metadata, Unset):
payload["metadata"] = metadata

data = await self._request_json(
"post",
f"/rollouts/{rollout_id}/attempts/{attempt_id}",
json=payload,
)
return Attempt.model_validate(data)
try:
payload: Dict[str, Any] = {}
if not isinstance(status, Unset):
payload["status"] = status
if not isinstance(worker_id, Unset):
payload["worker_id"] = worker_id
if not isinstance(last_heartbeat_time, Unset):
payload["last_heartbeat_time"] = last_heartbeat_time
if not isinstance(metadata, Unset):
payload["metadata"] = metadata

data = await self._request_json(
"post",
f"/rollouts/{rollout_id}/attempts/{attempt_id}",
json=payload,
)
return Attempt.model_validate(data)
except (ServerShutdownError, asyncio.CancelledError):
# Server is shutting down or request was cancelled - handle gracefully without traceback
client_logger.debug(
f"Server is shutting down or request cancelled. Skipping update_attempt for rollout {rollout_id}, attempt {attempt_id}."
)
raise

async def query_workers(
self,
Expand Down
Loading