-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathapi.py
More file actions
263 lines (203 loc) · 8.61 KB
/
api.py
File metadata and controls
263 lines (203 loc) · 8.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import os
import warnings
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum
from importlib.metadata import version as get_package_version
from pathlib import Path
import ray
import uvicorn
from config import load_config
from dotenv import dotenv_values
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
ray.init(dashboard_host="0.0.0.0")
# ray.init() has to run before the imports below, so E402 ("module level import not at top of file") is silenced for this file.
# flake8: noqa: E402
from routers.actors import router as actors_router
from routers.extract import router as extract_router
from routers.indexer import router as indexer_router
from routers.openai import router as openai_router
from routers.partition import router as partition_router
from routers.queue import router as queue_router
from routers.search import router as search_router
from routers.tools import router as tools_router
from routers.users import router as users_router
from mcp_server import create_mcp_http_app
from mcp_server import path as mcp_path
from mcp_server import server as mcp_server
from starlette.middleware.base import BaseHTTPMiddleware
from utils.dependencies import get_vectordb
from utils.exceptions import OpenRAGError
from utils.logger import get_logger
# pydub 0.25.1 contains regexes with invalid escape sequences, which trigger a
# SyntaxWarning on import. This is a known upstream issue, so the warning is
# silenced for that module only.
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pydub")

# SHARED_ENV optionally names a dotenv file; when it is unset or empty we start
# from an empty mapping instead of parsing a file.
SHARED_ENV = os.environ.get("SHARED_ENV")
if SHARED_ENV:
    env_vars = dotenv_values(SHARED_ENV)
else:
    env_vars = {}
env_vars["PYTHONPATH"] = "/app/openrag"

# Module-wide logger and configuration; DATA_DIR is the configured data
# directory as a Path.
logger = get_logger()
config = load_config()
DATA_DIR = Path(config.paths.data_dir)
class Tags(Enum):
    """Tag labels used to group the API routers.

    Every member's value is a plain string. Earlier revisions had stray
    trailing commas that turned all values except ``VDB`` into 1-tuples,
    which made the enum inconsistent and produced non-string tag values.
    """

    VDB = "VectorDB operations"
    INDEXER = "Indexer"
    SEARCH = "Semantic Search"
    OPENAI = "OpenAI Compatible API"
    EXTRACT = "Document extracts"
    PARTITION = "Partitions & files"
    QUEUE = "Queue management"
    ACTORS = "Ray Actors"
    USERS = "User management"
    TOOLS = "Tools"
class AppState:
    """Application-wide state: the loaded config and its data directory."""

    def __init__(self, config):
        # Convert the configured data directory to a Path once, up front.
        data_dir = Path(config.paths.data_dir)
        self.data_dir = data_dir
        self.config = config
# Read the token from env (or None if not set)
AUTH_TOKEN: str | None = os.getenv("AUTH_TOKEN")
INDEXERUI_PORT: str | None = os.getenv("INDEXERUI_PORT", "3042")
INDEXERUI_URL: str | None = os.getenv("INDEXERUI_URL", f"http://localhost:{INDEXERUI_PORT}")
WITH_CHAINLIT_UI: bool = os.getenv("WITH_CHAINLIT_UI", "true").lower() == "true"
WITH_OPENAI_API: bool = os.getenv("WITH_OPENAI_API", "true").lower() == "true"
try:
app_version = get_package_version("openrag")
except Exception:
app_version = "unknown"
_mcp_lifespan_stack = AsyncExitStack()


@asynccontextmanager
async def lifespan(app):
    """Run the MCP session manager for the lifetime of the application.

    On startup the context returned by ``mcp_server.session_manager.run()``
    is entered through the module-level exit stack; on shutdown the stack is
    closed again.

    Args:
        app: The application instance handed in by the framework (unused).
    """
    await _mcp_lifespan_stack.enter_async_context(mcp_server.session_manager.run())
    try:
        yield
    finally:
        # Close the stack even when shutdown arrives as an exception or a
        # cancellation thrown into the generator at the yield point; without
        # the finally the session manager context would never be exited.
        await _mcp_lifespan_stack.aclose()
app = FastAPI(version=app_version, lifespan=lifespan)


def custom_openapi():
    """Return the OpenAPI schema with a global bearer security requirement.

    The schema is generated on the first call and cached on
    ``app.openapi_schema``; later calls return the cached value.

    Returns:
        The OpenAPI schema dictionary for the application.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title="Openrag API",
        version=app.version,
        routes=app.routes,
    )

    # Add global security. The "components" section may be missing from the
    # generated schema (e.g. when no models are registered), so create it on
    # demand instead of indexing it directly and risking a KeyError.
    components = openapi_schema.setdefault("components", {})
    components["securitySchemes"] = {"BearerAuth": {"type": "http", "scheme": "bearer"}}
    openapi_schema["security"] = [{"BearerAuth": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


# Have the application serve the customised schema.
app.openapi = custom_openapi
class AuthMiddleware(BaseHTTPMiddleware):
    """Resolve the calling user and attach it to ``request.state``.

    When ``AUTH_TOKEN`` is not configured, every request is served as the user
    with id 1. Otherwise a token is required on all routes except a small set
    of public ones; the token is looked up in the vector DB and the matching
    user and its partitions are stored on ``request.state``.
    """

    # Exact paths served without a token when authentication is enabled.
    _OPEN_PATHS = ("/docs", "/openapi.json", "/redoc", "/health_check", "/version")

    @staticmethod
    def _read_token(request: Request):
        """Return the token carried by the request, or a falsy value if none."""
        if request.url.path.startswith("/static"):
            # /static accepts the token as a query parameter so files can be
            # opened from a plain link, without a bearer header.
            # usage http://localhost:8080/static?token=api_key
            return request.query_params.get("token", "")

        # Every other route must send an "Authorization: Bearer <token>" header.
        header = request.headers.get("authorization", "")
        if header and header.lower().startswith("bearer "):
            return header.split(" ", 1)[1]
        return None

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request, then forward it down the stack.

        Returns a 403 JSON response when the token is missing or unknown;
        otherwise returns whatever ``call_next`` produces.
        """
        vectordb = get_vectordb()

        # No AUTH_TOKEN configured: skip authentication and use user 1.
        if AUTH_TOKEN is None:
            default_user = await vectordb.get_user.remote(1)
            default_partitions = await vectordb.list_user_partitions.remote(1)
            request.state.user = default_user
            request.state.user_partitions = default_partitions
            return await call_next(request)

        # Public routes and all chainlit subroutes need no token.
        path = request.url.path
        if path in self._OPEN_PATHS or path.startswith("/chainlit"):
            return await call_next(request)

        token = self._read_token(request)
        if not token:
            return JSONResponse(status_code=403, content={"detail": "Missing token"})

        # Look the user up by token.
        user = await vectordb.get_user_by_token.remote(token)
        if not user:
            return JSONResponse(status_code=403, content={"detail": "Invalid token"})

        # Load the user's partitions, then expose both to downstream handlers.
        partitions = await vectordb.list_user_partitions.remote(user["id"])
        request.state.user = user
        request.state.user_partitions = partitions
        return await call_next(request)


# Register once
app.add_middleware(AuthMiddleware)
# Exception handlers
@app.exception_handler(OpenRAGError)
async def openrag_exception_handler(request: Request, exc: OpenRAGError):
    """Log an ``OpenRAGError`` and turn it into a JSON error response.

    The response uses the exception's own ``status_code`` and the payload
    returned by ``exc.to_dict()``.
    """
    log = get_logger()
    log.error("OpenRAGError occurred", error=str(exc))
    status = exc.status_code
    payload = exc.to_dict()
    return JSONResponse(status_code=status, content=payload)
# CORS: the two local development origins plus the configured indexer UI URL.
allow_origins = ["http://localhost:3042", "http://localhost:5173", INDEXERUI_URL]
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,  # Adjust as needed for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    # Let browser clients read the MCP session header from responses.
    expose_headers=["Mcp-Session-Id"],
)

# Shared application state, reachable from handlers through app.state.
app.state.app_state = AppState(config)

# Serve the data directory under /static and the MCP HTTP app under its own path.
app.mount(
    "/static",
    StaticFiles(directory=DATA_DIR.resolve(), check_dir=True),
    name="static",
)
app.mount(mcp_path, create_mcp_http_app(), name="mcp")
@app.get("/health_check", summary="Health check endpoint for API", dependencies=[])
async def health_check(request: Request):
    # Liveness probe: always answers with a fixed message.
    # TODO : Error reporting about llm and vlm
    status_message = "RAG API is up."
    return status_message
@app.get("/version", summary="Get openRAG version", dependencies=[])
def get_version():
    # Report the version string the FastAPI application was created with.
    version_payload = {"version": app.version}
    return version_payload
# Routers that are always available.
app.include_router(indexer_router, prefix="/indexer", tags=[Tags.INDEXER])
app.include_router(extract_router, prefix="/extract", tags=[Tags.EXTRACT])
app.include_router(search_router, prefix="/search", tags=[Tags.SEARCH])
app.include_router(partition_router, prefix="/partition", tags=[Tags.PARTITION])
app.include_router(queue_router, prefix="/queue", tags=[Tags.QUEUE])
app.include_router(actors_router, prefix="/actors", tags=[Tags.ACTORS])
app.include_router(users_router, prefix="/users", tags=[Tags.USERS])
app.include_router(tools_router, prefix="/v1", tags=[Tags.TOOLS])

if WITH_OPENAI_API:
    # Mount the openai router
    app.include_router(openai_router, prefix="/v1", tags=[Tags.OPENAI])

if WITH_CHAINLIT_UI:
    # Mount the default front
    from chainlit.utils import mount_chainlit

    mount_chainlit(app, "./app_front.py", path="/chainlit")
    if not WITH_OPENAI_API:
        # Chainlit uses the openai api endpoints, so they must exist even when
        # the OpenAI API flag is off. When the flag is on they are already
        # registered above; including the router a second time would duplicate
        # every /v1 route and its operation ID.
        app.include_router(openai_router, prefix="/v1", tags=[Tags.OPENAI])
if __name__ == "__main__":
    if not config.ray.serve.enable:
        # Ray Serve disabled: run the API directly under uvicorn.
        uvicorn.run("api:app", host="0.0.0.0", port=8080, reload=True, proxy_headers=True)
    else:
        from ray import serve

        # Wrap the FastAPI app as a Ray Serve deployment with the configured
        # number of replicas.
        @serve.deployment(num_replicas=config.ray.serve.num_replicas)
        @serve.ingress(app)
        class OpenRagAPI:
            pass

        serve.start(
            http_options={
                "host": config.ray.serve.host,
                "port": config.ray.serve.port,
            }
        )

        if not WITH_CHAINLIT_UI:
            # Only the API is served, so block here on the Serve application.
            serve.run(OpenRagAPI.bind(), route_prefix="/", blocking=True)
        else:
            from chainlit_api import app as chainlit_app

            # Start the API without blocking, then run the chainlit app under
            # uvicorn on its own port in this process.
            serve.run(OpenRagAPI.bind(), route_prefix="/")
            uvicorn.run(chainlit_app, host="0.0.0.0", port=config.ray.serve.chainlit_port)