|
1 | 1 | import os |
| 2 | +from typing import Dict, Any |
2 | 3 |
|
3 | 4 | import sqlbot_xpack |
4 | 5 | from alembic.config import Config |
5 | | -from fastapi import FastAPI |
| 6 | +from fastapi import FastAPI, Request |
6 | 7 | from fastapi.concurrency import asynccontextmanager |
| 8 | +from fastapi.openapi.utils import get_openapi |
| 9 | +from fastapi.responses import JSONResponse |
7 | 10 | from fastapi.routing import APIRoute |
8 | 11 | from fastapi.staticfiles import StaticFiles |
9 | 12 | from fastapi_mcp import FastApiMCP |
|
12 | 15 |
|
13 | 16 | from alembic import command |
14 | 17 | from apps.api import api_router |
15 | | -from common.utils.embedding_threads import fill_empty_table_and_ds_embeddings |
| 18 | +from apps.swagger.i18n import PLACEHOLDER_PREFIX, tags_metadata |
| 19 | +from apps.swagger.i18n import get_translation, DEFAULT_LANG |
16 | 20 | from apps.system.crud.aimodel_manage import async_model_info |
17 | 21 | from apps.system.crud.assistant import init_dynamic_cors |
18 | 22 | from apps.system.middleware.auth import TokenMiddleware |
19 | 23 | from common.core.config import settings |
20 | 24 | from common.core.response_middleware import ResponseMiddleware, exception_handler |
21 | 25 | from common.core.sqlbot_cache import init_sqlbot_cache |
22 | | -from common.utils.embedding_threads import fill_empty_terminology_embeddings, fill_empty_data_training_embeddings |
| 26 | +from common.utils.embedding_threads import fill_empty_terminology_embeddings, fill_empty_data_training_embeddings, \ |
| 27 | + fill_empty_table_and_ds_embeddings |
23 | 28 | from common.utils.utils import SQLBotLogUtil |
24 | 29 |
|
25 | 30 |
|
@@ -65,9 +70,104 @@ def custom_generate_unique_id(route: APIRoute) -> str: |
65 | 70 | title=settings.PROJECT_NAME, |
66 | 71 | openapi_url=f"{settings.API_V1_STR}/openapi.json", |
67 | 72 | generate_unique_id_function=custom_generate_unique_id, |
68 | | - lifespan=lifespan |
| 73 | + lifespan=lifespan, |
| 74 | + docs_url=None, |
| 75 | + redoc_url=None |
69 | 76 | ) |
70 | 77 |
|
| 78 | +# cache docs for different text |
| 79 | +_openapi_cache: Dict[str, Dict[str, Any]] = {} |
| 80 | + |
| 81 | +# replace placeholder |
| 82 | +def replace_placeholders_in_schema(schema: Dict[str, Any], trans: Dict[str, str]) -> None: |
| 83 | + """ |
| 84 | + search OpenAPI schema,replace PLACEHOLDER_xxx to text。 |
| 85 | + """ |
| 86 | + if isinstance(schema, dict): |
| 87 | + for key, value in schema.items(): |
| 88 | + if isinstance(value, str) and value.startswith(PLACEHOLDER_PREFIX): |
| 89 | + placeholder_key = value[len(PLACEHOLDER_PREFIX):] |
| 90 | + schema[key] = trans.get(placeholder_key, value) |
| 91 | + else: |
| 92 | + replace_placeholders_in_schema(value, trans) |
| 93 | + elif isinstance(schema, list): |
| 94 | + for item in schema: |
| 95 | + replace_placeholders_in_schema(item, trans) |
| 96 | + |
| 97 | + |
| 98 | + |
| 99 | +# OpenAPI build |
| 100 | +def get_language_from_request(request: Request) -> str: |
| 101 | + # get param from query ?lang=zh |
| 102 | + lang = request.query_params.get("lang") |
| 103 | + if lang in ["en", "zh"]: |
| 104 | + return lang |
| 105 | + # get lang from Accept-Language Header |
| 106 | + accept_lang = request.headers.get("accept-language", "") |
| 107 | + if "zh" in accept_lang.lower(): |
| 108 | + return "zh" |
| 109 | + return DEFAULT_LANG |
| 110 | + |
| 111 | + |
| 112 | +def generate_openapi_for_lang(lang: str) -> Dict[str, Any]: |
| 113 | + if lang in _openapi_cache: |
| 114 | + return _openapi_cache[lang] |
| 115 | + |
| 116 | + # tags metadata |
| 117 | + trans = get_translation(lang) |
| 118 | + localized_tags = [] |
| 119 | + for tag in tags_metadata: |
| 120 | + desc = tag["description"] |
| 121 | + if desc.startswith(PLACEHOLDER_PREFIX): |
| 122 | + key = desc[len(PLACEHOLDER_PREFIX):] |
| 123 | + desc = trans.get(key, desc) |
| 124 | + localized_tags.append({ |
| 125 | + "name": tag["name"], |
| 126 | + "description": desc |
| 127 | + }) |
| 128 | + |
| 129 | + # 1. create OpenAPI |
| 130 | + openapi_schema = get_openapi( |
| 131 | + title="SQLBot API Document" if lang == "en" else "SQLBot API 文档", |
| 132 | + version="1.0.0", |
| 133 | + routes=app.routes, |
| 134 | + tags=localized_tags |
| 135 | + ) |
| 136 | + |
| 137 | + # openapi version |
| 138 | + openapi_schema.setdefault("openapi", "3.1.0") |
| 139 | + |
| 140 | + # 2. get trans for lang |
| 141 | + trans = get_translation(lang) |
| 142 | + |
| 143 | + # 3. replace placeholder |
| 144 | + replace_placeholders_in_schema(openapi_schema, trans) |
| 145 | + |
| 146 | + # 4. cache |
| 147 | + _openapi_cache[lang] = openapi_schema |
| 148 | + return openapi_schema |
| 149 | + |
| 150 | + |
| 151 | + |
| 152 | +# custom /openapi.json and /docs |
| 153 | +@app.get("/openapi.json", include_in_schema=False) |
| 154 | +async def custom_openapi(request: Request): |
| 155 | + lang = get_language_from_request(request) |
| 156 | + schema = generate_openapi_for_lang(lang) |
| 157 | + return JSONResponse(schema) |
| 158 | + |
| 159 | + |
| 160 | +@app.get("/docs", include_in_schema=False) |
| 161 | +async def custom_swagger_ui(request: Request): |
| 162 | + lang = get_language_from_request(request) |
| 163 | + from fastapi.openapi.docs import get_swagger_ui_html |
| 164 | + return get_swagger_ui_html( |
| 165 | + openapi_url=f"/openapi.json?lang={lang}", |
| 166 | + title="SQLBot API Docs", |
| 167 | + swagger_favicon_url="https://fastapi.tiangolo.com/img/favicon.png", |
| 168 | + ) |
| 169 | + |
| 170 | + |
71 | 171 | mcp_app = FastAPI() |
72 | 172 | # mcp server, images path |
73 | 173 | images_path = settings.MCP_IMAGE_PATH |
|
0 commit comments