Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/litserve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,15 @@ def has_active_requests(self) -> bool:

def has_capacity(self) -> bool:
    """Placeholder capacity check that always raises.

    Raises:
        NotImplementedError: unconditionally; this implementation provides
            no capacity logic of its own.

    """
    message = "has_capacity is not implemented"
    raise NotImplementedError(message)

def health(self) -> bool:
    """Report any extra, application-specific health status of the API.

    The server's /health endpoint consults this method when deciding
    whether to answer healthy. Override it to plug in additional checks
    specific to your application; the default implementation performs
    no checks and always reports healthy.

    Returns:
        bool: True when the API is healthy, False otherwise.

    """
    is_healthy = True
    return is_healthy
3 changes: 2 additions & 1 deletion src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ async def health(request: Request) -> Response:
if not workers_ready:
workers_ready = all(v == WorkerSetupStatus.READY for v in self.workers_setup_status.values())

if workers_ready:
lit_api_health_status = self.lit_api.health()
if workers_ready and lit_api_health_status:
return Response(content="ok", status_code=200)

return Response(content="not ready", status_code=503)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def setup(self, device):
time.sleep(2)


class SlowSetupWithCustomHealthLitAPI(SimpleLitAPI):
    """Test API whose setup sleeps for 2 seconds and whose health check is delayed.

    ``health`` starts reporting True only once 5 seconds have elapsed since
    its first invocation.
    """

    def setup(self, device):
        """Install a squaring model, then sleep 2 seconds to simulate slow setup."""
        self.model = lambda x: x**2
        time.sleep(2)

    def health(self) -> bool:
        """Return True once at least 5 seconds have passed since the first call."""
        # Health check passes after 5 seconds from the first time it is called.
        # Use the monotonic clock for the elapsed-time measurement: unlike
        # time.time(), it cannot jump when the system clock is adjusted.
        if not hasattr(self, "_start_time"):
            self._start_time = time.monotonic()
        return time.monotonic() - self._start_time >= 5


@pytest.mark.parametrize("use_zmq", [True, False])
def test_workers_health(use_zmq):
server = LitServer(
Expand Down Expand Up @@ -103,6 +115,38 @@ def test_workers_health_custom_path(use_zmq):
assert response.text == "ok"


@pytest.mark.parametrize("use_zmq", [True, False])
def test_workers_health_with_custom_health_method(use_zmq):
    """Check /health against an API that defines its own ``health`` method.

    The endpoint must answer 503 "not ready" immediately, again after 1 second
    and again after a further second, then 200 "ok" after 4 more seconds.
    """
    api = SlowSetupWithCustomHealthLitAPI()
    server = LitServer(
        api,
        accelerator="cpu",
        devices=1,
        timeout=5,
        workers_per_device=2,
        fast_queue=use_zmq,
    )

    def expect_health(client, status_code, text):
        # Issue one GET /health and verify both the status code and the body.
        reply = client.get("/health")
        assert reply.status_code == status_code
        assert reply.text == text

    with wrap_litserve_start(server) as server, TestClient(server.app) as client:
        expect_health(client, 503, "not ready")

        # Two more probes, one second apart, must still be unhealthy.
        for pause in (1, 1):
            time.sleep(pause)
            expect_health(client, 503, "not ready")

        time.sleep(4)
        expect_health(client, 200, "ok")


def make_load_request(server, outputs):
with TestClient(server.app) as client:
for i in range(100):
Expand Down
Loading