Skip to content

Commit 61404ee

Browse files
perf: Model config page
1 parent 3c2c124 commit 61404ee

File tree

18 files changed

+446
-275
lines changed

18 files changed

+446
-275
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""019_upgrade_model
2+
3+
Revision ID: dcaecd481715
4+
Revises: 863105882eba
5+
Create Date: 2025-07-04 14:45:56.927188
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = 'dcaecd481715'
15+
down_revision = '863105882eba'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
op.drop_index(op.f('ix_ai_model_id'), table_name='ai_model')
22+
op.drop_table('ai_model')
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.create_table('ai_model',
25+
sa.Column('id', sa.BigInteger(), nullable=False),
26+
sa.Column('api_key', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True),
27+
sa.Column('api_domain', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
28+
sa.Column('protocol', sa.Integer(), nullable=False),
29+
sa.Column('supplier', sa.Integer(), nullable=False),
30+
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
31+
sa.Column('model_type', sa.Integer(), nullable=False),
32+
sa.Column('base_model', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
33+
sa.Column('default_model', sa.Boolean(), nullable=False),
34+
sa.Column('config', sa.Text(), nullable=False),
35+
sa.Column('status', sa.Integer(), nullable=False),
36+
sa.Column('create_time', sa.BigInteger(), nullable=False),
37+
sa.PrimaryKeyConstraint('id')
38+
)
39+
op.create_index(op.f('ix_ai_model_id'), 'ai_model', ['id'], unique=False)
40+
# ### end Alembic commands ###
41+
42+
def downgrade():
43+
# ### commands auto generated by Alembic - please adjust! ###
44+
op.drop_index(op.f('ix_ai_model_id'), table_name='ai_model')
45+
op.drop_table('ai_model')
46+
47+
op.create_table('ai_model',
48+
sa.Column('id', sa.BigInteger(), nullable=False),
49+
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
50+
sa.Column('type', sa.Integer(), nullable=False),
51+
sa.Column('api_key', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True),
52+
sa.Column('endpoint', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
53+
sa.Column('max_context_window', sa.Integer(), nullable=False),
54+
sa.Column('temperature', sa.Float(), nullable=False),
55+
sa.Column('status', sa.Boolean(), nullable=False),
56+
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True),
57+
sa.Column('create_time', sa.BigInteger(), nullable=False),
58+
sa.PrimaryKeyConstraint('id')
59+
)
60+
op.create_index(op.f('ix_ai_model_id'), 'ai_model', ['id'], unique=False)
61+
# ### end Alembic commands ###

backend/apps/ai_model/model_factory.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
import json
12
from pydantic import BaseModel
23
from typing import Optional, Dict, Any, Type
34
from abc import ABC, abstractmethod
45
from langchain_core.language_models import BaseLLM as LangchainBaseLLM
56
from langchain_openai import ChatOpenAI
7+
from sqlmodel import Session, select
68

9+
from common.core.db import engine
710
from apps.system.models.system_model import AiModelDetail
811

912

1013
# from langchain_community.llms import Tongyi, VLLM
1114

1215
class LLMConfig(BaseModel):
1316
"""Base configuration class for large language models"""
17+
model_id: Optional[int] = None
1418
model_type: str # Model type: openai/tongyi/vllm etc.
1519
model_name: str # Specific model name
1620
api_key: Optional[str] = None
@@ -93,12 +97,39 @@ def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]):
9397

9498

9599
# todo
96-
def get_llm_config(aimodel: AiModelDetail) -> LLMConfig:
100+
""" def get_llm_config(aimodel: AiModelDetail) -> LLMConfig:
97101
config = LLMConfig(
98102
model_type="openai",
99103
model_name=aimodel.name,
100104
api_key=aimodel.api_key,
101105
api_base_url=aimodel.endpoint,
102106
additional_params={"temperature": aimodel.temperature}
103107
)
104-
return config
108+
return config """
109+
110+
def get_default_config() -> LLMConfig:
111+
with Session(engine) as session:
112+
db_model = session.exec(
113+
select(AiModelDetail).where(AiModelDetail.default_model == True)
114+
).first()
115+
if not db_model:
116+
raise ValueError("The system default model has not been set")
117+
118+
additional_params = {}
119+
if db_model.config:
120+
try:
121+
config_raw = json.loads(db_model.config)
122+
additional_params = {item["key"]: item["val"] for item in config_raw if "key" in item and "val" in item}
123+
except Exception:
124+
pass
125+
126+
# 构造 LLMConfig
127+
return LLMConfig(
128+
model_id=db_model.id,
129+
model_type="openai",
130+
model_name=db_model.base_model,
131+
api_key=db_model.api_key,
132+
api_base_url=db_model.api_domain,
133+
additional_params=additional_params
134+
)
135+

backend/apps/chat/task/llm.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sqlalchemy import select
1212
from sqlalchemy.orm import load_only
1313

14-
from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_llm_config
14+
from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config
1515
from apps.chat.curd.chat import save_question, save_full_sql_message, save_full_sql_message_and_answer, save_sql, \
1616
save_error_message, save_sql_exec_data, save_full_chart_message, save_full_chart_message_and_answer, save_chart, \
1717
finish_record, save_full_analysis_message_and_answer, save_full_predict_message_and_answer, save_predict_data, \
@@ -21,7 +21,6 @@
2121
from apps.datasource.models.datasource import CoreDatasource
2222
from apps.db.db import exec_sql
2323
from apps.system.crud.user import get_user_info
24-
from apps.system.models.system_model import AiModelDetail
2524
from common.core.deps import SessionDep, CurrentUser
2625

2726
warnings.filterwarnings("ignore")
@@ -58,13 +57,6 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
5857

5958
chat_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
6059

61-
# Get available AI model
62-
aimodel = self.session.exec(select(AiModelDetail).where(
63-
AiModelDetail.status == True,
64-
AiModelDetail.api_key.is_not(None)
65-
)).first()
66-
if not aimodel and aimodel[0]:
67-
raise Exception("No available AI model configuration found")
6860

6961
history_records: List[ChatRecord] = list(
7062
map(lambda x: ChatRecord(**x.model_dump()), filter(lambda r: True if r.first_chat != True else False,
@@ -79,8 +71,9 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
7971

8072
self.ds = CoreDatasource(**ds.model_dump()) if ds else None
8173
self.chat_question = chat_question
82-
self.chat_question.ai_modal_id = aimodel[0].id
83-
self.config = get_llm_config(aimodel[0])
74+
self.config = get_default_config()
75+
self.chat_question.ai_modal_id = self.config.model_id
76+
8477

8578
# Create LLM instance through factory
8679
llm_instance = LLMFactory.create_llm(self.config)

backend/apps/system/api/aimodel.py

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,90 @@
1-
from fastapi import APIRouter
2-
from sqlmodel import select
1+
import json
2+
from typing import List, Union
3+
from apps.system.schemas.ai_model_schema import AiModelConfigItem, AiModelCreator, AiModelEditor, AiModelGridItem
4+
from fastapi import APIRouter, Query
5+
from sqlmodel import func, select
36

47
from apps.system.models.system_model import AiModelDetail
5-
from apps.system.schemas.system_schema import model_status
68
from common.core.deps import SessionDep
7-
from common.core.pagination import Paginator
8-
from common.core.schemas import PaginatedResponse, PaginationParams
99
from common.utils.time import get_timestamp
1010

1111
router = APIRouter(tags=["system/aimodel"], prefix="/system/aimodel")
1212

13-
14-
@router.get("/pager/{pageNum}/{pageSize}", response_model=PaginatedResponse[AiModelDetail])
15-
async def pager(
13+
@router.get("", response_model=list[AiModelGridItem])
14+
async def query(
1615
session: SessionDep,
17-
pageNum: int,
18-
pageSize: int
16+
keyword: Union[str, None] = Query(default=None, max_length=255)
1917
):
20-
pagination = PaginationParams(page=pageNum, size=pageSize)
21-
paginator = Paginator(session)
22-
filters = {}
23-
return await paginator.get_paginated_response(
24-
model=AiModelDetail,
25-
pagination=pagination,
26-
**filters)
27-
18+
statement = select(AiModelDetail.id,
19+
AiModelDetail.name,
20+
AiModelDetail.model_type,
21+
AiModelDetail.base_model,
22+
AiModelDetail.supplier,
23+
AiModelDetail.default_model)
24+
if keyword is not None:
25+
statement = statement.where(AiModelDetail.name.like(f"%{keyword}%"))
26+
27+
items = session.exec(statement).all()
28+
return items
2829

29-
@router.get("/{id}", response_model=AiModelDetail)
30+
@router.get("/{id}", response_model=AiModelEditor)
3031
async def get_model_by_id(
3132
session: SessionDep,
3233
id: int
3334
):
34-
term = session.get(AiModelDetail, id)
35-
return term
35+
db_model = session.get(AiModelDetail, id)
36+
if not db_model:
37+
raise ValueError(f"AiModelDetail with id {id} not found")
3638

39+
config_list: List[AiModelConfigItem] = []
40+
if db_model.config:
41+
try:
42+
raw = json.loads(db_model.config)
43+
config_list = [AiModelConfigItem(**item) for item in raw]
44+
except Exception:
45+
pass
46+
data = AiModelDetail.model_validate(db_model).model_dump(exclude_unset=True)
47+
data.pop("config", None)
48+
data["config_list"] = config_list
49+
return AiModelEditor(**data)
3750

38-
@router.post("", response_model=AiModelDetail)
51+
@router.post("")
3952
async def add_model(
4053
session: SessionDep,
41-
creator: AiModelDetail
54+
creator: AiModelCreator
4255
):
43-
data = AiModelDetail.model_validate(creator)
44-
data.create_time = get_timestamp()
45-
session.add(data)
56+
data = creator.model_dump(exclude_unset=True)
57+
data["config"] = json.dumps([item.model_dump(exclude_unset=True) for item in creator.config_list])
58+
data.pop("config_list", None)
59+
detail = AiModelDetail.model_validate(data)
60+
detail.create_time = get_timestamp()
61+
count = session.exec(select(func.count(AiModelDetail.id))).one()
62+
if count == 0:
63+
detail.default_model = True
64+
session.add(detail)
4665
session.commit()
47-
return creator
4866

49-
50-
@router.put("", response_model=AiModelDetail)
51-
async def update_terminology(
67+
@router.put("")
68+
async def update_model(
5269
session: SessionDep,
53-
model: AiModelDetail
70+
editor: AiModelEditor
5471
):
55-
model.id = int(model.id)
56-
term = session.exec(select(AiModelDetail).where(AiModelDetail.id == model.id)).first()
57-
update_data = model.model_dump(exclude_unset=True)
58-
for field, value in update_data.items():
59-
setattr(term, field, value)
60-
session.add(term)
72+
id = int(editor.id)
73+
data = editor.model_dump(exclude_unset=True)
74+
data["config"] = json.dumps([item.model_dump(exclude_unset=True) for item in editor.config_list])
75+
data.pop("config_list", None)
76+
db_model = session.get(AiModelDetail, id)
77+
update_data = AiModelDetail.model_validate(data)
78+
db_model.sqlmodel_update(update_data)
79+
session.add(db_model)
6180
session.commit()
62-
return model
63-
6481

65-
@router.delete("/{id}", response_model=AiModelDetail)
66-
async def delete_terminology(
82+
@router.delete("/{id}")
83+
async def delete_model(
6784
session: SessionDep,
6885
id: int
6986
):
70-
term = session.exec(select(AiModelDetail).where(AiModelDetail.id == id)).first()
71-
session.delete(term)
72-
session.commit()
73-
return {
74-
"message": f"AiModel with ID {id} deleted successfully."
75-
}
76-
77-
78-
@router.patch("/status", response_model=dict)
79-
async def status(session: SessionDep, dto: model_status):
80-
ids = dto.ids
81-
status = dto.status
82-
if not ids:
83-
return {"message": "ids is empty"}
84-
statement = select(AiModelDetail).where(AiModelDetail.id.in_(ids))
85-
terms = session.exec(statement).all()
86-
for term in terms:
87-
term.status = status
88-
session.add(term)
87+
item = session.get(AiModelDetail, id)
88+
session.delete(item)
8989
session.commit()
90-
return {"message": f"AiModel with IDs {ids} updated successfully."}
90+
Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
2-
from sqlmodel import BigInteger, Field, SQLModel
1+
2+
from sqlmodel import BigInteger, Field, Text, SQLModel
33
from common.core.models import SnowflakeBase
44

55

66
class AiModelBase:
7+
supplier: int = Field(nullable=False)
78
name: str = Field(max_length=255, nullable=False)
8-
type: int = Field(nullable=False)
9+
model_type: int = Field(nullable=False)
10+
base_model: str = Field(max_length = 255, nullable=False)
11+
default_model: bool = Field(default=False, nullable=False)
912

10-
class AiModelDetail(AiModelBase, SnowflakeBase, table=True):
13+
class AiModelDetail(SnowflakeBase, AiModelBase, table=True):
1114
__tablename__ = "ai_model"
1215
api_key: str | None = Field(max_length=255, nullable=True)
13-
endpoint: str = Field(max_length=255, nullable=False)
14-
max_context_window: int = Field(default=0)
15-
temperature: float = Field(default=0.0)
16-
status: bool = Field(default=True)
17-
description: str | None = Field(max_length=255, nullable=True)
16+
api_domain: str = Field(max_length=255, nullable=False)
17+
protocol: int = Field(nullable=False, default = 1)
18+
config: str = Field(sa_type = Text())
19+
status: int = Field(nullable=False, default = 1)
1820
create_time: int = Field(default=0, sa_type=BigInteger())
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
from typing import List
3+
from pydantic import BaseModel
4+
5+
from common.core.schemas import BaseCreatorDTO
6+
7+
class AiModelItem(BaseModel):
8+
name: str
9+
model_type: int
10+
base_model: str
11+
supplier: int
12+
default_model: bool = False
13+
14+
class AiModelGridItem(AiModelItem, BaseCreatorDTO):
15+
pass
16+
17+
class AiModelConfigItem(BaseModel):
18+
key: str
19+
val: object
20+
name: str
21+
22+
class AiModelCreator(AiModelItem):
23+
api_domain: str
24+
api_key: str
25+
config_list: List[AiModelConfigItem]
26+
27+
class AiModelEditor(AiModelCreator, BaseCreatorDTO):
28+
pass

0 commit comments

Comments
 (0)