Skip to content

Commit c7c7416

Browse files
perf: Chat and model
1 parent c41c760 commit c7c7416

File tree

16 files changed

+332
-13
lines changed

16 files changed

+332
-13
lines changed

backend/apps/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from apps.system.api import login, user, aimodel
44
from apps.settings.api import terminology
55
from apps.datasource.api import datasource
6+
from apps.chat.api import chat
67

78

89
api_router = APIRouter()
@@ -11,5 +12,6 @@
1112
api_router.include_router(aimodel.router)
1213
api_router.include_router(terminology.router)
1314
api_router.include_router(datasource.router)
15+
api_router.include_router(chat.router)
1416

1517

backend/apps/chat/__init__.py

Whitespace-only changes.

backend/apps/chat/api/__init__.py

Whitespace-only changes.

backend/apps/chat/api/chat.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from fastapi import APIRouter, HTTPException
2+
from sqlmodel import select
3+
from apps.chat.schemas.chat_base_schema import LLMConfig
4+
from apps.chat.schemas.chat_schema import ChatQuestion
5+
from apps.chat.schemas.llm import LLMService
6+
from apps.system.models.system_modle import AiModelDetail
7+
from common.core.deps import SessionDep
8+
# from sse_starlette.sse import EventSourceResponse
9+
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
10+
11+
12+
@router.post("/question")
13+
async def stream_sql(session: SessionDep, requestQuestion: ChatQuestion):
14+
question = requestQuestion.question
15+
16+
# Use OpenAI model
17+
""" openai_config = LLMConfig(
18+
model_type="openai",
19+
model_name="gpt-4",
20+
api_key="your-api-key",
21+
additional_params={"temperature": 0.7}
22+
)
23+
openai_service = LLMService(openai_config) """
24+
25+
aimodel = session.exec(select(AiModelDetail).where(AiModelDetail.status == True, AiModelDetail.api_key.is_not(None))).first()
26+
27+
if not aimodel:
28+
raise HTTPException(
29+
status_code=400,
30+
detail="No available AI model configuration found"
31+
)
32+
33+
# Use Tongyi Qianwen
34+
tongyi_config = LLMConfig(
35+
model_type="tongyi",
36+
model_name=aimodel.name,
37+
api_key=aimodel.api_key,
38+
additional_params={"temperature": aimodel.temperature}
39+
)
40+
llm_service = LLMService(tongyi_config)
41+
42+
# Use Custom VLLM model
43+
""" vllm_config = LLMConfig(
44+
model_type="vllm",
45+
model_name="your-model-path",
46+
additional_params={
47+
"max_new_tokens": 200,
48+
"temperature": 0.3
49+
}
50+
)
51+
vllm_service = LLMService(vllm_config) """
52+
result = llm_service.generate_sql(question)
53+
return result

backend/apps/chat/schemas/__init__.py

Whitespace-only changes.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from pydantic import BaseModel
2+
from typing import Optional, Dict, Any, Type
3+
from abc import ABC, abstractmethod
4+
from langchain_core.language_models import BaseLLM as LangchainBaseLLM
5+
from langchain_openai import ChatOpenAI
6+
from langchain_community.llms import Tongyi, VLLM
7+
8+
class LLMConfig(BaseModel):
9+
"""Base configuration class for large language models"""
10+
model_type: str # Model type: openai/tongyi/vllm etc.
11+
model_name: str # Specific model name
12+
api_key: Optional[str] = None
13+
api_base_url: Optional[str] = None
14+
additional_params: Dict[str, Any] = {}
15+
16+
17+
class BaseLLM(ABC):
18+
"""Abstract base class for large language models"""
19+
20+
def __init__(self, config: LLMConfig):
21+
self.config = config
22+
self._llm = self._init_llm()
23+
24+
@abstractmethod
25+
def _init_llm(self) -> LangchainBaseLLM:
26+
"""Initialize specific large language model instance"""
27+
pass
28+
29+
@property
30+
def llm(self) -> LangchainBaseLLM:
31+
"""Return the langchain LLM instance"""
32+
return self._llm
33+
34+
class OpenAILLM(BaseLLM):
35+
def _init_llm(self) -> LangchainBaseLLM:
36+
return ChatOpenAI(
37+
model=self.config.model_name,
38+
api_key=self.config.api_key,
39+
**self.config.additional_params
40+
)
41+
42+
def generate(self, prompt: str) -> str:
43+
return self.llm.invoke(prompt)
44+
45+
class TongyiLLM(BaseLLM):
46+
def _init_llm(self) -> LangchainBaseLLM:
47+
return Tongyi(
48+
model_name=self.config.model_name,
49+
dashscope_api_key=self.config.api_key,
50+
**self.config.additional_params
51+
)
52+
53+
def generate(self, prompt: str) -> str:
54+
return self.llm.invoke(prompt)
55+
56+
class VLLMLLM(BaseLLM):
57+
def _init_llm(self) -> LangchainBaseLLM:
58+
return VLLM(
59+
model=self.config.model_name,
60+
**self.config.additional_params
61+
)
62+
63+
def generate(self, prompt: str) -> str:
64+
return self.llm.invoke(prompt)
65+
66+
67+
class LLMFactory:
68+
"""Large Language Model Factory Class"""
69+
70+
_llm_types: Dict[str, Type[BaseLLM]] = {
71+
"openai": OpenAILLM,
72+
"tongyi": TongyiLLM,
73+
"vllm": VLLMLLM
74+
}
75+
76+
@classmethod
77+
def create_llm(cls, config: LLMConfig) -> BaseLLM:
78+
llm_class = cls._llm_types.get(config.model_type)
79+
if not llm_class:
80+
raise ValueError(f"Unsupported LLM type: {config.model_type}")
81+
return llm_class(config)
82+
83+
@classmethod
84+
def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]):
85+
"""Register new model type"""
86+
cls._llm_types[model_type] = llm_class
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
from pydantic import BaseModel
3+
4+
5+
class ChatQuestion(BaseModel):
6+
question: str

backend/apps/chat/schemas/llm.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from langchain_community.utilities import SQLDatabase
2+
# from langchain_community.agent_toolkits import create_sql_agent
3+
from langchain_community.llms import Tongyi
4+
from langchain_core.prompts import ChatPromptTemplate
5+
from apps.chat.schemas.chat_base_schema import LLMConfig, LLMFactory
6+
from common.core.config import settings
7+
import warnings
8+
9+
warnings.filterwarnings("ignore")
10+
11+
class LLMService:
12+
def __init__(self, config: LLMConfig):
13+
# Initialize database connection
14+
self.db = SQLDatabase.from_uri(str(settings.SQLALCHEMY_DATABASE_URI))
15+
16+
# Create LLM instance through factory
17+
llm_instance = LLMFactory.create_llm(config)
18+
self.llm = llm_instance.llm
19+
20+
# Define prompt template
21+
self.prompt = ChatPromptTemplate.from_messages([
22+
("system", """You are a professional SQL engineer. Generate PostgreSQL SELECT queries based on the database schema and user questions.
23+
Data modification or deletion is prohibited. Table structure is as follows:
24+
{schema}
25+
"""),
26+
("human", "{question}")
27+
])
28+
29+
def generate_sql(self, question: str) -> str:
30+
chain = self.prompt | self.llm
31+
schema = self.db.get_table_info()
32+
return chain.invoke({"schema": schema, "question": question})
33+

backend/common/core/response_middleware.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,14 @@ async def dispatch(self, request, call_next):
4949
)
5050
except Exception as e:
5151
logging.error(f"Response processing error: {str(e)}", exc_info=True)
52-
return response
52+
return JSONResponse(
53+
status_code=500,
54+
content={
55+
"code": 500,
56+
"data": None,
57+
"msg": str(e)
58+
}
59+
)
5360

5461
return response
5562

backend/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from fastapi.staticfiles import StaticFiles
44
import os
55
import sentry_sdk
6-
from fastapi import FastAPI, Path
6+
from fastapi import FastAPI, Path, HTTPException
77
from fastapi.routing import APIRoute
88
from starlette.middleware.cors import CORSMiddleware
9+
from starlette.exceptions import HTTPException as StarletteHTTPException
910
from apps.api import api_router
1011
from apps.system.middleware.auth import TokenMiddleware
1112
from common.core.config import settings
12-
from common.core.response_middleware import ResponseMiddleware
13+
from common.core.response_middleware import ResponseMiddleware, exception_handler
1314

1415
def custom_generate_unique_id(route: APIRoute) -> str:
1516
tag = route.tags[0] if route.tags and len(route.tags) > 0 else ""
@@ -39,6 +40,9 @@ def custom_generate_unique_id(route: APIRoute) -> str:
3940
app.add_middleware(ResponseMiddleware)
4041
app.include_router(api_router, prefix=settings.API_V1_STR)
4142

43+
# Register exception handlers
44+
app.add_exception_handler(StarletteHTTPException, exception_handler.http_exception_handler)
45+
app.add_exception_handler(Exception, exception_handler.global_exception_handler)
4246

4347
frontend_dist = os.path.abspath("../frontend/dist")
4448
if not os.path.exists(frontend_dist):

0 commit comments

Comments
 (0)