Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions openrag/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import os
import time
import warnings
from enum import Enum
from importlib.metadata import version as get_package_version
from pathlib import Path

import httpx
import ray
import uvicorn
from config import load_config
Expand Down Expand Up @@ -196,10 +199,85 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
app.mount("/static", StaticFiles(directory=DATA_DIR.resolve(), check_dir=True), name="static")


async def check_service_health(base_url: str, service_name: str, timeout: float = 3.0) -> dict:
    """
    Probe a service's /health endpoint with a bounded timeout.

    Args:
        base_url: Base URL of the service (e.g., "http://localhost:8000")
        service_name: Human-readable name for logging
        timeout: Per-request timeout in seconds (default: 3.0)

    Returns:
        dict with status (healthy/unhealthy/timeout/unreachable/error),
        response_time_ms when a response was received, and an error
        message when the probe did not succeed.
    """
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
            response = await client.get(f"{base_url}/health")
            # client.get() fully reads the body, so response.elapsed is populated here.
            elapsed_ms = response.elapsed.total_seconds() * 1000

            if response.status_code == 200:
                return {"status": "healthy", "response_time_ms": round(elapsed_ms, 2)}
            return {
                "status": "unhealthy",
                "error": f"HTTP {response.status_code}",
                "response_time_ms": round(elapsed_ms, 2),
            }
    except httpx.TimeoutException:
        # {:g} renders the default 3.0 as "3", preserving the original message text.
        return {"status": "timeout", "error": f"Service did not respond within {timeout:g}s"}
    except httpx.ConnectError:
        return {"status": "unreachable", "error": "Connection refused"}
    except Exception as e:
        # Catch-all so a probe failure can never crash the health endpoint itself.
        return {"status": "error", "error": str(e)}


@app.get("/health_check", summary="Health check endpoint for API", dependencies=[])
async def health_check(request: Request):
    """
    Health check endpoint with LLM and VLM service probes.

    Returns HTTP 200 for healthy/degraded, HTTP 503 for unhealthy.
    LLM is critical, VLM is non-critical (used only for image captioning).
    """
    # NOTE: the previous stub (`return "RAG API is up."` placed before the
    # docstring) made everything below unreachable; it has been removed so the
    # probe logic actually runs.
    config = request.app.state.app_state.config

    # Probe LLM and VLM services concurrently.
    # Strip API path (e.g. /v1/) to get the service root for health probes.
    llm_base_url = config.llm.get("base_url", "").split("/v1")[0]
    vlm_base_url = config.vlm.get("base_url", "").split("/v1")[0]

    results = await asyncio.gather(
        check_service_health(llm_base_url, "llm"),
        check_service_health(vlm_base_url, "vlm"),
        return_exceptions=True,
    )

    # Defensive: with return_exceptions=True, gather may hand back an Exception
    # object instead of a result dict; normalize both slots to dicts.
    llm_result = results[0] if not isinstance(results[0], Exception) else {"status": "error", "error": str(results[0])}
    vlm_result = results[1] if not isinstance(results[1], Exception) else {"status": "error", "error": str(results[1])}

    # Determine overall status.
    llm_healthy = llm_result.get("status") == "healthy"
    vlm_healthy = vlm_result.get("status") == "healthy"

    if llm_healthy and vlm_healthy:
        overall_status = "healthy"
        status_code = 200
    elif llm_healthy and not vlm_healthy:
        # VLM is non-critical (only used for image captioning).
        overall_status = "degraded"
        status_code = 200
    else:
        # LLM is critical - any LLM failure is unhealthy.
        overall_status = "unhealthy"
        status_code = 503

    response_data = {
        "status": overall_status,
        "checks": {"api": {"status": "healthy"}, "llm": llm_result, "vlm": vlm_result},
        "timestamp": time.time(),
    }

    return JSONResponse(status_code=status_code, content=response_data)


@app.get("/version", summary="Get openRAG version", dependencies=[])
Expand Down
100 changes: 92 additions & 8 deletions openrag/scripts/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def read_rdb_section(
added_documents: dict[str, set[str]],
existing_partitions: dict[str, Any],
logger: Any,
restore_state: dict[str, Any],
user_id: int,
verbose: bool = False,
dry_run: bool = False,
Expand Down Expand Up @@ -69,19 +70,44 @@ def read_rdb_section(
try:
res = pfm.add_file_to_partition(doc["file_id"], part["name"], doc, user_id)
except Exception as e:
logger.exception(
f"{type(e)} in add_file_to_partition({doc['file_id']}, {part['name']}, ...)\n" + str(e)
)
raise
# Non-critical failure: log and continue instead of raising
logger.bind(
file_id=doc["file_id"],
partition=part["name"],
error_type=type(e).__name__,
).error(f"Failed to add file to partition: {str(e)}")
restore_state["files_failed"] += 1
if len(restore_state["errors"]) < 100:
restore_state["errors"].append(
{
"file_id": doc["file_id"],
"partition": part["name"],
"error": str(e),
}
)
res = False
else:
res = True

if res:
if part["name"] not in added_documents:
added_documents[part["name"]] = set()
# Track partition creation (first file added successfully)
restore_state["partitions_created"].append(part["name"])
added_documents[part["name"]].add(doc["file_id"])
restore_state["files_added"] += 1
else:
logger.error(f"Can't add file {doc['file_id']} to partition {part['name']}")
if not dry_run:
logger.error(f"Can't add file {doc['file_id']} to partition {part['name']}")

# Log progress every 100 files
total_files = restore_state["files_added"] + restore_state["files_failed"]
if total_files > 0 and total_files % 100 == 0:
logger.bind(
files_added=restore_state["files_added"],
files_failed=restore_state["files_failed"],
total_processed=total_files,
).info("Restore progress")


def insert_into_vdb(
Expand Down Expand Up @@ -126,6 +152,7 @@ def read_vdb_section(
client: MilvusClient,
batch_size: int,
logger: Any,
restore_state: dict[str, Any],
verbose: bool = False,
dry_run: bool = False,
) -> None:
Expand Down Expand Up @@ -155,6 +182,7 @@ def read_vdb_section(

if len(batch) >= batch_size:
insert_into_vdb(client, collection_name, batch, logger, verbose, dry_run)
restore_state["chunks_inserted"] += len(batch)
batch = []

chunk = json.loads(line)
Expand All @@ -165,6 +193,7 @@ def read_vdb_section(

if len(batch) > 0:
insert_into_vdb(client, collection_name, batch, logger, verbose, dry_run)
restore_state["chunks_inserted"] += len(batch)


def open_backup_file(file_name: str, logger: Any) -> IO[str]:
Expand Down Expand Up @@ -254,15 +283,25 @@ def load_openrag_config(logger: Any) -> tuple[dict[str, Any], dict[str, Any]]:

logger = get_logger()

restore_state = {
"partitions_created": [], # List of partition names created in RDB
"files_added": 0, # Count of files successfully added
"files_failed": 0, # Count of files that failed
"chunks_inserted": 0, # Count of VDB chunks inserted
"errors": [], # List of error dicts: {"file_id", "partition", "error"}
}

try:
# It will create the Milvus collection if it doesn't exist
vdb_tmp = MilvusDB.options(name="Vectordb", namespace="openrag", lifetime="detached").remote()

await vdb_tmp.__ray_ready__.remote()  # ensure the actor is fully initialized and ready: collection and all created if not existing
await (
vdb_tmp.__ray_ready__.remote()
)  # ensure the actor is fully initialized and ready: collection and all created if not existing
print("VectorDB (Milvus) actor fully initialized")
except Exception as e:
logger.exception(f"Failed while trying to create Milvus collection: {e}")
# TODO: stop execution here
return 1

rdb, vdb = load_openrag_config(logger)

Expand Down Expand Up @@ -306,6 +345,7 @@ def load_openrag_config(logger: Any) -> tuple[dict[str, Any], dict[str, Any]]:
added_documents,
existing_partitions,
logger,
restore_state,
args.user_id,
args.verbose,
args.dry_run,
Expand All @@ -319,11 +359,55 @@ def load_openrag_config(logger: Any) -> tuple[dict[str, Any], dict[str, Any]]:
client,
args.batch_size,
logger,
restore_state,
args.verbose,
args.dry_run,
)

# Log final summary
logger.bind(
partitions_restored=len(restore_state["partitions_created"]),
files_added=restore_state["files_added"],
files_failed=restore_state["files_failed"],
chunks_inserted=restore_state["chunks_inserted"],
).info("Restore completed")

if restore_state["errors"]:
logger.bind(
total_errors=len(restore_state["errors"]),
first_10=restore_state["errors"][:10],
).warning("Restore completed with file-level errors")
except Exception as e:
logger.error("Error: " + str(e))
logger.bind(
error=str(e),
partitions_created=restore_state["partitions_created"],
files_added=restore_state["files_added"],
files_failed=restore_state["files_failed"],
).error("Critical restore failure - initiating rollback")

# Rollback in reverse order: VDB first, then RDB
for partition_name in reversed(restore_state["partitions_created"]):
# 1. Delete from VDB first (no FK constraints, orphaned vectors are worse)
try:
client.delete(
collection_name=vdb["collection_name"],
filter=f'partition == "{partition_name}"',
)
logger.info(f"VDB rollback succeeded for partition: {partition_name}")
except Exception:
logger.exception(f"VDB rollback failed for partition {partition_name}")

# 2. Delete from RDB (cascades to files via FK)
try:
pfm.delete_partition(partition_name)
logger.info(f"RDB rollback succeeded for partition: {partition_name}")
except Exception:
logger.exception(f"RDB rollback failed for partition {partition_name}")

logger.bind(
partitions_rolled_back=len(restore_state["partitions_created"]),
).error("Rollback complete")

raise
finally:
client.close()
Expand Down
3 changes: 2 additions & 1 deletion tests/api_tests/test_health.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ def test_health_check(api_client):
"""Test health check endpoint returns OK."""
response = api_client.get("/health_check")
assert response.status_code == 200
assert "RAG API is up" in response.text
data = response.json()
assert data["status"] in ("healthy", "degraded")


def test_openapi_docs_accessible(api_client):
Expand Down