|
5 | 5 | import time |
6 | 6 | from collections.abc import Iterable |
7 | 7 | from contextlib import asynccontextmanager, suppress |
| 8 | +from importlib.metadata import PackageNotFoundError, version |
8 | 9 |
|
9 | 10 | import procrastinate |
10 | | -from fastapi import APIRouter, FastAPI, HTTPException |
| 11 | +from fastapi import APIRouter, FastAPI, HTTPException, Request |
11 | 12 | from fastapi.exception_handlers import http_exception_handler |
12 | | -from fastapi.responses import ORJSONResponse |
| 13 | +from fastapi.openapi.docs import get_swagger_ui_html |
| 14 | +from fastapi.openapi.utils import get_openapi |
| 15 | +from fastapi.responses import JSONResponse, ORJSONResponse |
13 | 16 | from kink import Container, di, inject |
14 | 17 | from opentelemetry.metrics import CallbackOptions, Observation, get_meter |
15 | 18 | from procrastinate.exceptions import AlreadyEnqueued |
16 | | -from starlette.requests import Request |
17 | 19 | from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_500_INTERNAL_SERVER_ERROR |
18 | 20 |
|
19 | 21 | from agentstack_server.api.routes.a2a import router as a2a_router |
|
46 | 48 | logger = logging.getLogger(__name__) |
47 | 49 |
|
48 | 50 |
|
| 51 | +def get_version(): |
| 52 | + try: |
| 53 | + __version__ = version("agentstack-server") |
| 54 | + except PackageNotFoundError: |
| 55 | + __version__ = "0.1.0" |
| 56 | + |
| 57 | + return __version__ |
| 58 | + |
| 59 | + |
49 | 60 | def extract_messages(exc): |
50 | 61 | if isinstance(exc, BaseExceptionGroup): |
51 | 62 | return [(exc_type, msg) for e in exc.exceptions for exc_type, msg in extract_messages(e)] |
@@ -111,6 +122,28 @@ def mount_routes(app: FastAPI): |
111 | 122 | app.include_router(server_router, prefix="/api/v1", tags=["provider"]) |
112 | 123 | app.include_router(well_known_router, prefix="/.well-known", tags=["well-known"]) |
113 | 124 |
|
| 125 | + @app.get("/api/v1/openapi.json", include_in_schema=False) |
| 126 | + async def custom_openapi(request: Request): |
| 127 | + openapi_schema = get_openapi( |
| 128 | + title="Agentstack server", |
| 129 | + version=get_version(), |
| 130 | + routes=app.routes, |
| 131 | + ) |
| 132 | + |
| 133 | + base_url = str(request.base_url) |
| 134 | + openapi_schema["servers"] = [{"url": base_url}] |
| 135 | + |
| 136 | + return JSONResponse(openapi_schema) |
| 137 | + |
| 138 | + @app.get("/api/v1/docs", include_in_schema=False) |
| 139 | + async def custom_docs(request: Request): |
| 140 | + openapi_url = request.url_for(custom_openapi.__name__) |
| 141 | + |
| 142 | + return get_swagger_ui_html( |
| 143 | + openapi_url=openapi_url, |
| 144 | + title="BeeAI Platform API Docs", |
| 145 | + ) |
| 146 | + |
114 | 147 | @app.get("/healthcheck") |
115 | 148 | async def healthcheck(): |
116 | 149 | return "OK" |
@@ -169,9 +202,9 @@ async def lifespan(_app: FastAPI, procrastinate_app: procrastinate.App, mcp_serv |
169 | 202 | app = FastAPI( |
170 | 203 | lifespan=lifespan, |
171 | 204 | default_response_class=ORJSONResponse, # better performance then default + handle NaN floats |
172 | | - docs_url="/api/v1/docs", |
173 | | - openapi_url="/api/v1/openapi.json", |
174 | | - servers=[{"url": f"http://localhost:{configuration.port}"}], |
| 205 | + docs_url=None, |
| 206 | + openapi_url=None, |
| 207 | + servers=None, |
175 | 208 | ) |
176 | 209 |
|
177 | 210 | logger.info("Mounting routes...") |
|
0 commit comments