Skip to content

Commit f789fe8

Browse files
committed
feat: Add health check to API and client
Add health check endpoint that checks the health of the database, object store, and queue, returning a JSON status report. Leveraging this, add a `health_check` method to the client to verify service health, along with an "is_healthy" boolean method for convenience. Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 87e4327 commit f789fe8

File tree

8 files changed

+367
-2
lines changed

8 files changed

+367
-2
lines changed

client/cogstack_model_gateway_client/client.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,43 @@ async def deploy_model(
312312
resp = await self._request("POST", url, json=data)
313313
return resp.json()
314314

315+
@require_client
async def health_check(self):
    """Check if the Gateway and its components are healthy and responsive.

    A 503 response is treated as a valid "unhealthy" report rather than an
    error, so callers always receive a status dictionary when the Gateway
    answered at all.

    Returns:
        dict: Health status information with 'status' and 'components' fields.
            Status will be 'healthy' or 'unhealthy'. When a 503 body cannot be
            decoded into a JSON object, a minimal
            ``{"status": "unhealthy", "error": "Service unavailable"}`` report
            is returned instead.

    Raises:
        httpx.HTTPStatusError: For HTTP errors other than 503 Service Unavailable.
        Other exceptions: For network errors, timeouts, etc.
    """
    url = f"{self.base_url}/health"
    try:
        resp = await self._request("GET", url)
        return resp.json()
    except httpx.HTTPStatusError as e:
        if e.response.status_code != 503:
            raise

        fallback = {"status": "unhealthy", "error": "Service unavailable"}
        try:
            body = e.response.json()
        except Exception:
            # Best effort: a 503 with an unreadable body still means "unhealthy".
            return fallback

        if not isinstance(body, dict):
            return fallback

        # The Gateway reports 503 by raising an HTTPException, which FastAPI
        # serialises as {"detail": <report>}. Unwrap it so the report has the
        # same top-level shape ('status', 'components') as the 200 response.
        detail = body.get("detail")
        if isinstance(detail, dict):
            return detail
        return body
338+
339+
@require_client
async def is_healthy(self):
    """Tell whether the Gateway reports an overall 'healthy' status.

    Convenience wrapper around ``health_check`` that never raises: any
    exception while obtaining or reading the report is treated as unhealthy.

    Returns:
        bool: True if overall status is 'healthy', False otherwise.
    """
    try:
        report = await self.health_check()
        overall = report.get("status")
    except Exception:
        # Unreachable service, unexpected payload, etc. all count as unhealthy.
        return False
    return overall == "healthy"
351+
315352

316353
class GatewayClientSync:
317354
"""A simplified synchronous wrapper around the async GatewayClient.
@@ -574,3 +611,24 @@ def deploy_model(
574611
ttl=ttl,
575612
)
576613
)
614+
615+
def health_check(self):
    """Synchronously fetch the Gateway health report.

    Runs the wrapped async client's ``health_check`` coroutine to completion
    with ``asyncio.run`` and hands back its result unchanged.

    Returns:
        dict: Health status information with 'status' and 'components' fields.
            Status will be 'healthy' or 'unhealthy'.

    Raises:
        httpx.HTTPStatusError: For HTTP errors other than 503 Service Unavailable.
        Other exceptions: For network errors, timeouts, etc.
    """
    pending_report = self._client.health_check()
    return asyncio.run(pending_report)
627+
628+
def is_healthy(self):
    """Synchronously tell whether the Gateway reports a 'healthy' status.

    Runs the wrapped async client's ``is_healthy`` coroutine to completion
    with ``asyncio.run`` and hands back its result unchanged.

    Returns:
        bool: True if overall status is 'healthy', False otherwise.
    """
    pending_verdict = self._client.is_healthy()
    return asyncio.run(pending_verdict)

cogstack_model_gateway/common/db.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,13 @@ def get_session(self):
4242
"""Get a database session."""
4343
with Session(self.engine) as session:
4444
yield session
45+
46+
def health_check(self) -> bool:
    """Perform a health check by executing a simple query.

    Opens a session through ``get_session`` and runs ``SELECT 1`` against the
    database. Any failure is logged and reported as unhealthy instead of
    being raised.

    Returns:
        bool: True if the query executed, False if any exception occurred.
    """
    # Imported locally so this block does not depend on the module's import list.
    from sqlmodel import text

    try:
        with self.get_session() as session:
            # Raw SQL must be wrapped in text(): SQLAlchemy 2.x rejects plain
            # strings, which would make this probe fail on a healthy database.
            session.exec(text("SELECT 1"))
            return True
    except Exception as e:
        log.error("Health check failed for database: %s", e)
        return False

cogstack_model_gateway/common/object_store.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,12 @@ def get_object_url(
8888
"""Get a presigned URL for the specified object."""
8989
bucket_name = bucket_name if bucket_name else self.default_bucket
9090
return self.client.presigned_get_object(bucket_name, object_key, expires=expires)
91+
92+
def health_check(self) -> bool:
    """Perform a health check by listing objects in the configured default bucket.

    Any failure is logged and reported as unhealthy instead of being raised.

    Returns:
        bool: True if the listing request succeeded (even for an empty
            bucket), False if any exception occurred.
    """
    try:
        listing = self.client.list_objects(
            self.default_bucket, extra_query_params={"max-keys": "1"}
        )
        # list_objects returns a lazy iterator: no request is sent until it is
        # consumed. Pull at most one item so the store is actually contacted;
        # the None default keeps an empty bucket from raising StopIteration.
        next(iter(listing), None)
        return True
    except Exception as e:
        log.error("Health check failed for bucket '%s': %s", self.default_bucket, e)
        return False

cogstack_model_gateway/common/queue.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,11 @@ def is_queue_empty(self):
178178
return queue.method.message_count == 0
179179
except pika.exceptions.ChannelClosed:
180180
return True
181+
182+
@with_connection
def health_check(self) -> bool:
    """Perform a health check by verifying a connection can be established.

    NOTE(review): presumably ``with_connection`` sets up ``self.connection``
    before this body runs - confirm against the decorator's definition.

    Returns:
        bool: True if ``self.connection`` is set and not closed, False
            otherwise. A failure is also logged.
    """
    # Coerce to bool so the annotated return type holds even when the
    # connection is missing (a bare `and` would return None).
    healthy = bool(self.connection) and not self.connection.is_closed
    if not healthy:
        log.error("Health check failed for queue '%s'", self.queue_name)
    return healthy

cogstack_model_gateway/gateway/main.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from contextlib import asynccontextmanager
44

55
import urllib3
6-
from fastapi import FastAPI
6+
from fastapi import FastAPI, HTTPException
77
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
88

9-
from cogstack_model_gateway.common.config import load_config
9+
from cogstack_model_gateway.common.config import get_config, load_config
1010
from cogstack_model_gateway.common.db import DatabaseManager
1111
from cogstack_model_gateway.common.logging import configure_logging
1212
from cogstack_model_gateway.common.object_store import ObjectStoreManager
@@ -99,3 +99,41 @@ async def prometheus_request_counter(request, call_next):
9999
async def root():
100100
"""Root endpoint for the gateway API."""
101101
return {"message": "Enter the cult... I mean, the API."}
102+
103+
104+
@app.get("/health")
async def health_check():
    """Health check endpoint that verifies the status of critical components.

    Probes the database, both object stores, and the queue through their
    ``health_check`` methods.

    Returns:
        dict: ``{"status": "healthy", "components": {...}}`` when every
            component reports healthy.

    Raises:
        HTTPException: 503 with the full status report as ``detail`` when at
            least one component is unhealthy, or 503 with an ``error`` message
            when the probes themselves could not be run.
    """
    # Only the probing step is guarded. The 503 raised further down for an
    # unhealthy component must not be caught here, otherwise the structured
    # per-component report would be replaced by a generic error string.
    try:
        config = get_config()

        components_to_check = {
            "database": config.database_manager,
            "task_object_store": config.task_object_store_manager,
            "results_object_store": config.results_object_store_manager,
            "queue": config.queue_manager,
        }

        component_status = {
            name: "healthy" if manager.health_check() else "unhealthy"
            for name, manager in components_to_check.items()
        }
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail={"status": "unhealthy", "error": f"Failed to perform health check: {str(e)}"},
        ) from e

    overall_status = (
        "healthy"
        if all(status == "healthy" for status in component_status.values())
        else "unhealthy"
    )

    health_status = {"status": overall_status, "components": component_status}

    if overall_status == "unhealthy":
        raise HTTPException(status_code=503, detail=health_status)

    return health_status

docker-compose.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ services:
128128
- gateway
129129
- cogstack-model-serve_cms
130130
- observability
131+
healthcheck:
132+
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
133+
interval: 60s
134+
timeout: 20s
135+
retries: 3
131136

132137
scheduler:
133138
image: cogstacksystems/cogstack-model-gateway-scheduler:${CMG_IMAGE_TAG:-latest}

tests/integration/test_api.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,24 @@ def test_root(client: TestClient):
9191
assert response.json() == {"message": "Enter the cult... I mean, the API."}
9292

9393

94+
def test_health_check(client: TestClient):
    """The /health endpoint answers 200 and reports every component as healthy."""
    reply = client.get("/health")
    assert reply.status_code == 200

    report = reply.json()
    for top_level_key in ("status", "components"):
        assert top_level_key in report

    required = ("database", "task_object_store", "results_object_store", "queue")
    reported = report["components"]
    for name in required:
        assert name in reported

    # In a properly configured test environment, all components should be healthy
    assert report["status"] == "healthy"
    for name in required:
        assert reported[name] == "healthy"
110+
111+
94112
def test_get_tasks(client: TestClient):
95113
response = client.get("/tasks/")
96114
assert response.status_code == 403

0 commit comments

Comments
 (0)