Handle server shutdown gracefully to prevent traceback spam #408
base: main
Changes from all commits
2fdd43e
1156651
fb928f6
41948fa
397a34c
ec5b0e4
ca835ca
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,6 +78,15 @@ | |
| T_model = TypeVar("T_model", bound=BaseModel) | ||
|
|
||
|
|
||
| class ServerShutdownError(Exception): | ||
|
||
| """Raised when the server is shutting down and requests cannot be completed. | ||
|
|
||
| This exception is raised instead of ServerDisconnectedError when we detect | ||
| that the server is permanently unavailable (e.g., during graceful shutdown). | ||
| Callers should handle this gracefully without dumping full tracebacks. | ||
| """ | ||
|
|
||
|
|
||
| class RolloutRequest(BaseModel): | ||
| input: TaskInput | ||
| mode: Optional[Literal["train", "val", "test"]] = None | ||
|
|
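For orientation, a caller-side sketch of how the new exception is meant to be used; the function name, arguments, and the `"failed"` status value below are placeholders, not taken from the diff:

```python
# Sketch only: `store` stands in for the HTTP-backed store client touched by
# this diff. The point is that shutdown-time failures no longer dump a full
# traceback at the call site.
async def report_failure(store, rollout_id: str, attempt_id: str) -> None:
    try:
        await store.update_attempt(rollout_id, attempt_id, status="failed")
    except ServerShutdownError:
        # The server is going away; skip the update instead of spamming logs.
        pass
```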
@@ -252,6 +261,8 @@ def __init__( | |
| launch_mode: LaunchMode = "thread", | ||
| launcher_args: PythonServerLauncherArgs | None = None, | ||
| n_workers: int = 1, | ||
| prometheus: bool = False, | ||
| server_online_flag: Any = None, | ||
|
Contributor: This should not be `Any`, and it should be documented in the docstring.

Contributor (Author): removed

Contributor: maybe you should document it as whatever the type of |
||
| tracker: MetricsBackend | None = None, | ||
| ): | ||
| super().__init__() | ||
|
|
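The parameter is typed `Any` in this revision (see the thread above). From the `.get_lock()` / `.value` usage later in the diff it behaves like a `multiprocessing.Value`; a minimal wiring sketch under that assumption, assuming the constructors shown in this diff belong to `LightningStoreServer` and `LightningStoreClient`, with all other constructor arguments omitted:

```python
from multiprocessing import Value

# Assumption: the flag is a process-shared boolean created once by whoever
# owns both the server and the client workers ("b" = signed char used as 0/1).
server_online_flag = Value("b", False)

# Both sides receive the same shared object; the address is illustrative and
# the remaining constructor arguments would be passed exactly as before.
server = LightningStoreServer(server_online_flag=server_online_flag)
client = LightningStoreClient("http://localhost:8000", server_online_flag=server_online_flag)
```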
@@ -301,12 +312,15 @@ def __init__( | |
| # LightningStoreServer holds a plain Python object (self.store) in one process | ||
| # (the process that runs uvicorn/FastAPI). | ||
| # When you multiprocessing.Process(...) and call methods on a different LightningStore instance | ||
| # (or on a copy inherited via fork), you’re mutating another process’s memory, not the server’s memory. | ||
| # (or on a copy inherited via fork), you're mutating another process's memory, not the server's memory. | ||
| # So we need to track the owner process (whoever creates the server), | ||
| # and only mutate the store in that process. | ||
| self._owner_pid = os.getpid() | ||
| self._client: Optional[LightningStoreClient] = None | ||
|
|
||
| # Set to True after server launches, False before server stops | ||
| self._server_online_flag = server_online_flag | ||
|
|
||
| @property | ||
| def capabilities(self) -> LightningStoreCapabilities: | ||
| """Return the capabilities of the store.""" | ||
|
|
@@ -416,6 +430,11 @@ async def start(self): | |
| end_time = time.time() | ||
| server_logger.info(f"Lightning store server started in {end_time - start_time:.2f} seconds") | ||
|
|
||
| # Set server online flag to True after server has launched | ||
| if self._server_online_flag is not None: | ||
| with self._server_online_flag.get_lock(): | ||
| self._server_online_flag.value = True | ||
|
|
||
| async def run_forever(self): | ||
| """Runs the FastAPI server indefinitely.""" | ||
| server_logger.info( | ||
|
|
@@ -428,6 +447,11 @@ async def stop(self): | |
|
|
||
| You need to call this method in the same process as the server was created in. | ||
| """ | ||
| # Set server online flag to False before server stops | ||
| if self._server_online_flag is not None: | ||
| with self._server_online_flag.get_lock(): | ||
| self._server_online_flag.value = False | ||
|
|
||
| server_logger.info("Stopping the lightning store server...") | ||
| await self.server_launcher.stop() | ||
| server_logger.info("Lightning store server stopped.") | ||
|
|
@@ -1355,6 +1379,7 @@ def __init__( | |
| health_retry_delays: Sequence[float] = (0.1, 0.2, 0.5), | ||
| request_timeout: float = 30.0, | ||
| connection_timeout: float = 5.0, | ||
| server_online_flag: Any = None, | ||
| ): | ||
| self.server_address_root = server_address.rstrip("/") | ||
| self.server_address = self.server_address_root + API_V1_AGL_PREFIX | ||
|
|
@@ -1371,7 +1396,9 @@ def __init__( | |
|
|
||
| # Store whether the dequeue was successful in history | ||
| self._dequeue_was_successful: bool = False | ||
| self._dequeue_first_unsuccessful: bool = True | ||
|
|
||
| # When requests fail, check this flag - if server is not online, silently ignore errors | ||
| self._server_online_flag = server_online_flag | ||
|
|
||
| @property | ||
| def capabilities(self) -> LightningStoreCapabilities: | ||
|
|
@@ -1421,7 +1448,6 @@ def __setstate__(self, state: Dict[str, Any]): | |
| self._request_timeout = state["_request_timeout"] | ||
| self._connection_timeout = state["_connection_timeout"] | ||
| self._dequeue_was_successful = False | ||
| self._dequeue_first_unsuccessful = True | ||
|
|
||
| async def _get_session(self) -> aiohttp.ClientSession: | ||
| # In the proxy process, FastAPI middleware calls | ||
|
|
@@ -1459,6 +1485,7 @@ async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool: | |
| """ | ||
| Probe the server's /health until it responds 200 or retries are exhausted. | ||
| Returns True if healthy, False otherwise. | ||
| When this returns False, it indicates the server is shutting down or permanently unavailable. | ||
| """ | ||
| if not self._health_retry_delays: | ||
| client_logger.info("No health retry delays configured; skipping health checks.") | ||
|
|
@@ -1477,8 +1504,9 @@ async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool: | |
| client_logger.warning(f"Server is not healthy yet. Retrying in {delay} seconds.") | ||
| if delay > 0.0: | ||
| await asyncio.sleep(delay) | ||
| client_logger.error( | ||
| f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts" | ||
| client_logger.warning( | ||
| f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts. " | ||
| "Server appears to be shutting down." | ||
| ) | ||
| return False | ||
|
|
||
|
|
@@ -1540,10 +1568,54 @@ async def _request_json( | |
| last_exc = net_exc | ||
| client_logger.info(f"Network/session issue: {net_exc} - will retry the request {method}: {path}") | ||
| if not await self._wait_until_healthy(session): | ||
| # Check shared flag - if server is not online, silently ignore error | ||
| if self._server_online_flag is not None: | ||
| with self._server_online_flag.get_lock(): | ||
| is_online = bool(self._server_online_flag.value) | ||
| if not is_online: | ||
| client_logger.debug( | ||
| f"Server is not online (shared flag). Silently ignoring {type(net_exc).__name__} for {method}: {path}" | ||
| ) | ||
| # Silently ignore - return None to indicate failure was expected | ||
| return None | ||
| break # server is not healthy, do not retry | ||
| except asyncio.CancelledError as cancel_exc: | ||
| # Cancellation can occur during async operations, especially during shutdown | ||
| client_logger.debug(f"Request cancelled: {method}: {path}", exc_info=True) | ||
| # Check shared flag - if server is not online, silently ignore error | ||
| if self._server_online_flag is not None: | ||
| with self._server_online_flag.get_lock(): | ||
| is_online = bool(self._server_online_flag.value) | ||
| if not is_online: | ||
| client_logger.debug( | ||
| f"Server is not online (shared flag). Silently ignoring CancelledError for {method}: {path}" | ||
| ) | ||
| # Silently ignore - return None to indicate failure was expected | ||
| return None | ||
|
||
| # If flag not available or server is online, re-raise cancellation | ||
| raise cancel_exc | ||
|
|
||
| # exhausted retries | ||
| assert last_exc is not None | ||
| # Before raising, check shared flag - if server is not online, silently ignore error | ||
| if isinstance( | ||
| last_exc, | ||
| ( | ||
| aiohttp.ServerDisconnectedError, | ||
| aiohttp.ClientConnectorError, | ||
| aiohttp.ClientOSError, | ||
| asyncio.TimeoutError, | ||
| ), | ||
| ): | ||
| if self._server_online_flag is not None: | ||
| with self._server_online_flag.get_lock(): | ||
| is_online = bool(self._server_online_flag.value) | ||
| if not is_online: | ||
| client_logger.debug( | ||
| f"Server is not online (shared flag). Silently ignoring {type(last_exc).__name__} for {method}: {path}" | ||
| ) | ||
| # Silently ignore - return None to indicate failure was expected | ||
| return None | ||
| raise last_exc | ||
|
|
||
| async def close(self): | ||
|
|
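The check-the-flag-then-return-None pattern now appears in three branches of `_request_json` (network errors, cancellation, and exhausted retries). Purely as an illustration, not something this PR does, the repetition could be folded into one small method; the name `_server_is_offline` is hypothetical:

```python
def _server_is_offline(self) -> bool:
    """Hypothetical helper: True when a shared flag is present and reports the
    server as offline; False when no flag was provided."""
    if self._server_online_flag is None:
        return False
    with self._server_online_flag.get_lock():
        return not bool(self._server_online_flag.value)
```

Each of the three branches above would then reduce to a short `if self._server_is_offline(): client_logger.debug(...); return None` check.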
@@ -1649,10 +1721,18 @@ async def _dequeue_batch( | |
| self._dequeue_was_successful = True | ||
| return [AttemptedRollout.model_validate(item) for item in data] | ||
| except Exception as e: | ||
| # Check shared flag - if server is not online, silently ignore error | ||
| if self._server_online_flag is not None: | ||
| with self._server_online_flag.get_lock(): | ||
| is_online = bool(self._server_online_flag.value) | ||
| if not is_online: | ||
| client_logger.debug( | ||
| f"Server is not online (shared flag). Silently ignoring dequeue_rollout failure: {e}" | ||
| ) | ||
| return None | ||
| # Log warning if server was online and dequeue was successful before (transition from online to offline) | ||
| if self._dequeue_was_successful: | ||
|
||
| if self._dequeue_first_unsuccessful: | ||
| client_logger.warning(f"dequeue_rollout failed with exception: {e}") | ||
| self._dequeue_first_unsuccessful = False | ||
| client_logger.warning(f"dequeue_rollout failed with exception: {e}") | ||
| client_logger.debug("dequeue_rollout failed with exception. Details:", exc_info=True) | ||
| # Else ignore the exception because the server is not ready yet | ||
| return [] | ||
|
|
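The flattened side-by-side rendering above makes the removed and added branches hard to tell apart. As far as it can be read, the new behavior is "warn once, then stay quiet"; a condensed sketch of just that part:

```python
# New behavior (as read from the diff): only the first dequeue failure is
# logged at warning level; the flag is then cleared so later failures do not
# repeat the warning.
if self._dequeue_first_unsuccessful:
    client_logger.warning(f"dequeue_rollout failed with exception: {e}")
    self._dequeue_first_unsuccessful = False
```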
@@ -1916,16 +1996,23 @@ async def add_otel_span( | |
| readable_span: ReadableSpan, | ||
| sequence_id: int | None = None, | ||
| ) -> Optional[Span]: | ||
| # unchanged logic, now benefits from retries inside add_span/get_next_span_sequence_id | ||
| if sequence_id is None: | ||
| sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id) | ||
| span = Span.from_opentelemetry( | ||
| readable_span, | ||
| rollout_id=rollout_id, | ||
| attempt_id=attempt_id, | ||
| sequence_id=sequence_id, | ||
| ) | ||
| return await self.add_span(span) | ||
| try: | ||
| # unchanged logic, now benefits from retries inside add_span/get_next_span_sequence_id | ||
| if sequence_id is None: | ||
| sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id) | ||
| span = Span.from_opentelemetry( | ||
| readable_span, | ||
| rollout_id=rollout_id, | ||
| attempt_id=attempt_id, | ||
| sequence_id=sequence_id, | ||
| ) | ||
| return await self.add_span(span) | ||
| except (ServerShutdownError, asyncio.CancelledError): | ||
| # Server is shutting down or request was cancelled - handle gracefully without traceback | ||
| client_logger.debug( | ||
| f"Server is shutting down or request cancelled. Skipping add_otel_span for rollout {rollout_id}, attempt {attempt_id}." | ||
| ) | ||
| return None | ||
|
|
||
| async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]: | ||
| """Wait for rollouts to complete. | ||
|
|
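A hedged usage note on the new `add_otel_span` behavior: during shutdown the call now returns `None` instead of raising, so a caller flushing its last spans can treat `None` as "dropped". The wrapper below is illustrative only:

```python
async def flush_last_span(client, rollout_id: str, attempt_id: str, readable_span) -> bool:
    # Sketch with placeholder names: `client` is the HTTP store client from
    # this diff. Returns False when the span was dropped because the request
    # was cancelled or the server was shutting down.
    span = await client.add_otel_span(rollout_id, attempt_id, readable_span)
    return span is not None
```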
@@ -2030,22 +2117,29 @@ async def update_attempt( | |
| last_heartbeat_time: float | Unset = UNSET, | ||
| metadata: Optional[Dict[str, Any]] | Unset = UNSET, | ||
| ) -> Attempt: | ||
| payload: Dict[str, Any] = {} | ||
| if not isinstance(status, Unset): | ||
| payload["status"] = status | ||
| if not isinstance(worker_id, Unset): | ||
| payload["worker_id"] = worker_id | ||
| if not isinstance(last_heartbeat_time, Unset): | ||
| payload["last_heartbeat_time"] = last_heartbeat_time | ||
| if not isinstance(metadata, Unset): | ||
| payload["metadata"] = metadata | ||
|
|
||
| data = await self._request_json( | ||
| "post", | ||
| f"/rollouts/{rollout_id}/attempts/{attempt_id}", | ||
| json=payload, | ||
| ) | ||
| return Attempt.model_validate(data) | ||
| try: | ||
| payload: Dict[str, Any] = {} | ||
| if not isinstance(status, Unset): | ||
| payload["status"] = status | ||
| if not isinstance(worker_id, Unset): | ||
| payload["worker_id"] = worker_id | ||
| if not isinstance(last_heartbeat_time, Unset): | ||
| payload["last_heartbeat_time"] = last_heartbeat_time | ||
| if not isinstance(metadata, Unset): | ||
| payload["metadata"] = metadata | ||
|
|
||
| data = await self._request_json( | ||
| "post", | ||
| f"/rollouts/{rollout_id}/attempts/{attempt_id}", | ||
| json=payload, | ||
| ) | ||
| return Attempt.model_validate(data) | ||
| except (ServerShutdownError, asyncio.CancelledError): | ||
| # Server is shutting down or request was cancelled - handle gracefully without traceback | ||
| client_logger.debug( | ||
| f"Server is shutting down or request cancelled. Skipping update_attempt for rollout {rollout_id}, attempt {attempt_id}." | ||
| ) | ||
| raise | ||
|
|
||
| async def query_workers( | ||
| self, | ||
|
|
||
Please add unit tests for this feature!

In which file do you want the unit tests, or should I create a new one?

in tests/execution/test_client_server.py
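Not part of the PR thread, but for orientation, a hedged sketch of what a test covering the shared flag could look like in `tests/execution/test_client_server.py`. The import path, the use of `pytest-asyncio`, and the constructor arguments are assumptions about the repository, not verified code:

```python
import multiprocessing

import pytest

from agentlightning.store.client_server import LightningStoreClient  # assumed import path


@pytest.mark.asyncio  # assumes pytest-asyncio is configured as for other async tests
async def test_failed_requests_ignored_when_flag_reports_offline():
    """A connection failure should be swallowed (None returned) when the shared
    flag says the server is offline, instead of raising and spamming tracebacks."""
    flag = multiprocessing.Value("b", False)  # server marked offline

    client = LightningStoreClient(
        "http://127.0.0.1:9",        # nothing listens on the discard port
        server_online_flag=flag,
        health_retry_delays=(),      # skip health probing to keep the test fast
        request_timeout=0.5,
        connection_timeout=0.5,
    )
    try:
        result = await client._request_json("get", "/rollouts")
        assert result is None  # silently ignored because the flag reports offline
    finally:
        await client.close()
```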