|
1 | 1 | from datetime import timedelta |
2 | | -from typing import Optional |
3 | | -from fastapi import APIRouter, FastAPI, Query, Request, Response |
| 2 | +import json |
| 3 | +import os |
| 4 | +from typing import List, Optional |
| 5 | +from fastapi import APIRouter, FastAPI, Form, HTTPException, Query, Request, Response, UploadFile |
| 6 | +from fastapi.responses import StreamingResponse |
4 | 7 | from sqlmodel import Session, select |
5 | 8 | from apps.system.crud.assistant import get_assistant_info |
6 | 9 | from apps.system.models.system_model import AssistantModel |
7 | 10 | from apps.system.schemas.auth import CacheName, CacheNamespace |
8 | | -from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantValidator |
| 11 | +from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantUiSchema, AssistantValidator |
9 | 12 | from common.core.deps import SessionDep, Trans |
10 | 13 | from common.core.security import create_access_token |
11 | 14 | from common.core.sqlbot_cache import clear_cache |
12 | 15 | from common.utils.time import get_timestamp |
13 | 16 | from starlette.middleware.cors import CORSMiddleware |
14 | 17 | from common.core.config import settings |
15 | 18 | from common.utils.utils import get_origin_from_referer |
| 19 | +from sqlbot_xpack.file_utils import SQLBotFileUtils |
16 | 20 | router = APIRouter(tags=["system/assistant"], prefix="/system/assistant") |
17 | 21 |
|
18 | 22 | @router.get("/info/{id}") |
@@ -48,6 +52,55 @@ async def validator(session: SessionDep, id: int, virtual: Optional[int] = Query |
48 | 52 | ) |
49 | 53 | return AssistantValidator(True, True, True, access_token) |
50 | 54 |
|
| 55 | +@router.get('/picture/{file_id}') |
| 56 | +async def picture(file_id: str): |
| 57 | + file_path = SQLBotFileUtils.get_file_path(file_id=file_id) |
| 58 | + if not os.path.exists(file_path): |
| 59 | + raise HTTPException(status_code=404, detail="File not found") |
| 60 | + |
| 61 | + if file_id.lower().endswith(".svg"): |
| 62 | + media_type = "image/svg+xml" |
| 63 | + else: |
| 64 | + media_type = "image/jpeg" |
| 65 | + |
| 66 | + def iterfile(): |
| 67 | + with open(file_path, mode="rb") as f: |
| 68 | + yield from f |
| 69 | + |
| 70 | + return StreamingResponse(iterfile(), media_type=media_type) |
| 71 | + |
| 72 | +@router.patch('/ui') |
| 73 | +async def ui(session: SessionDep, data: str = Form(), files: List[UploadFile] = []): |
| 74 | + json_data = json.loads(data) |
| 75 | + uiSchema = AssistantUiSchema(**json_data) |
| 76 | + id = uiSchema.id |
| 77 | + db_model = session.get(AssistantModel, id) |
| 78 | + if not db_model: |
| 79 | + raise ValueError(f"AssistantModel with id {id} not found") |
| 80 | + configuration = db_model.configuration |
| 81 | + config_obj = json.loads(configuration) if configuration else {} |
| 82 | + |
| 83 | + ui_schema_dict = uiSchema.model_dump(exclude_none=True, exclude_unset=True) |
| 84 | + if files: |
| 85 | + for file in files: |
| 86 | + origin_file_name = file.filename |
| 87 | + file_name, flag_name = SQLBotFileUtils.split_filename_and_flag(origin_file_name) |
| 88 | + file.filename = file_name |
| 89 | + if flag_name == 'logo' or flag_name == 'float_icon': |
| 90 | + SQLBotFileUtils.check_file(file=file, file_types=[".jpg", ".jpeg", ".png", ".svg"], limit_file_size=(10 * 1024 * 1024)) |
| 91 | + SQLBotFileUtils.detete_file(config_obj.get(flag_name)) |
| 92 | + file_id = await SQLBotFileUtils.upload(file) |
| 93 | + ui_schema_dict[flag_name] = file_id |
| 94 | + else: |
| 95 | + raise ValueError(f"Unsupported file flag: {flag_name}") |
| 96 | + |
| 97 | + for attr, value in ui_schema_dict.items(): |
| 98 | + if attr != 'id' and not attr.startswith("__"): |
| 99 | + config_obj[attr] = value |
| 100 | + |
| 101 | + db_model.configuration = json.dumps(config_obj, ensure_ascii=False) |
| 102 | + session.add(db_model) |
| 103 | + session.commit() |
51 | 104 |
|
52 | 105 | @router.get("", response_model=list[AssistantModel]) |
53 | 106 | async def query(session: SessionDep): |
|
0 commit comments