|
2 | 2 | import json |
3 | 3 | import os |
4 | 4 | from typing import List, Optional |
5 | | -from fastapi import APIRouter, FastAPI, Form, HTTPException, Query, Request, Response, UploadFile |
| 5 | +from fastapi import APIRouter, Form, HTTPException, Query, Request, Response, UploadFile |
6 | 6 | from fastapi.responses import StreamingResponse |
7 | | -from sqlmodel import Session, select |
| 7 | +from sqlmodel import select |
8 | 8 | from apps.system.crud.assistant import get_assistant_info |
| 9 | +from apps.system.crud.assistant_manage import dynamic_upgrade_cors, save |
9 | 10 | from apps.system.models.system_model import AssistantModel |
10 | 11 | from apps.system.schemas.auth import CacheName, CacheNamespace |
11 | 12 | from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantUiSchema, AssistantValidator |
12 | 13 | from common.core.deps import SessionDep, Trans |
13 | 14 | from common.core.security import create_access_token |
14 | 15 | from common.core.sqlbot_cache import clear_cache |
15 | 16 | from common.utils.time import get_timestamp |
16 | | -from starlette.middleware.cors import CORSMiddleware |
| 17 | + |
17 | 18 | from common.core.config import settings |
18 | 19 | from common.utils.utils import get_origin_from_referer |
19 | 20 | from sqlbot_xpack.file_utils import SQLBotFileUtils |
| 21 | + |
20 | 22 | router = APIRouter(tags=["system/assistant"], prefix="/system/assistant") |
21 | 23 |
|
22 | 24 | @router.get("/info/{id}") |
@@ -104,16 +106,12 @@ async def ui(session: SessionDep, data: str = Form(), files: List[UploadFile] = |
104 | 106 |
|
105 | 107 | @router.get("", response_model=list[AssistantModel]) |
106 | 108 | async def query(session: SessionDep): |
107 | | - list_result = session.exec(select(AssistantModel).order_by(AssistantModel.name, AssistantModel.create_time)).all() |
| 109 | + list_result = session.exec(select(AssistantModel).where(AssistantModel.type.in_([0, 1])).order_by(AssistantModel.name, AssistantModel.create_time)).all() |
108 | 110 | return list_result |
109 | 111 |
|
110 | 112 | @router.post("") |
111 | 113 | async def add(request: Request, session: SessionDep, creator: AssistantBase): |
112 | | - db_model = AssistantModel.model_validate(creator) |
113 | | - db_model.create_time = get_timestamp() |
114 | | - session.add(db_model) |
115 | | - session.commit() |
116 | | - dynamic_upgrade_cors(request=request, session=session) |
| 114 | + save(request, session, creator) |
117 | 115 |
|
118 | 116 |
|
119 | 117 | @router.put("") |
@@ -147,26 +145,7 @@ async def delete(request: Request, session: SessionDep, id: int): |
147 | 145 | session.commit() |
148 | 146 | dynamic_upgrade_cors(request=request, session=session) |
149 | 147 |
|
150 | | -def dynamic_upgrade_cors(request: Request, session: Session): |
151 | | - list_result = session.exec(select(AssistantModel).order_by(AssistantModel.create_time)).all() |
152 | | - seen = set() |
153 | | - unique_domains = [] |
154 | | - for item in list_result: |
155 | | - if item.domain: |
156 | | - for domain in item.domain.split(','): |
157 | | - domain = domain.strip() |
158 | | - if domain and domain not in seen: |
159 | | - seen.add(domain) |
160 | | - unique_domains.append(domain) |
161 | | - app: FastAPI = request.app |
162 | | - cors_middleware = None |
163 | | - for middleware in app.user_middleware: |
164 | | - if middleware.cls == CORSMiddleware: |
165 | | - cors_middleware = middleware |
166 | | - break |
167 | | - if cors_middleware: |
168 | | - updated_origins = list(set(settings.all_cors_origins + unique_domains)) |
169 | | - cors_middleware.kwargs['allow_origins'] = updated_origins |
| 148 | + |
170 | 149 |
|
171 | 150 |
|
172 | 151 |
|
|
0 commit comments