
Commit 684cb4f

Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents: ca57be2 + 2a11709

File tree

6 files changed (+43 lines, -11 lines)

backend/apps/ai_model/model_factory.py

Lines changed: 11 additions & 4 deletions
@@ -11,7 +11,7 @@
 from apps.system.models.system_model import AiModelDetail
 from common.core.db import engine
 from common.utils.utils import prepare_model_arg
-
+from langchain_community.llms import VLLMOpenAI

 # from langchain_community.llms import Tongyi, VLLM

@@ -60,7 +60,14 @@ def llm(self) -> BaseChatModel:
         """Return the langchain LLM instance"""
         return self._llm

-
+class OpenAIvLLM(BaseLLM):
+    def _init_llm(self) -> VLLMOpenAI:
+        return VLLMOpenAI(
+            openai_api_key=self.config.api_key or 'Empty',
+            openai_api_base=self.config.api_base_url,
+            model_name=self.config.model_name,
+            **self.config.additional_params,
+        )
 class OpenAILLM(BaseLLM):
     def _init_llm(self) -> BaseChatModel:
         return BaseChatOpenAI(
@@ -81,7 +88,7 @@ class LLMFactory:
     _llm_types: Dict[str, Type[BaseLLM]] = {
         "openai": OpenAILLM,
         "tongyi": OpenAILLM,
-        "vllm": OpenAILLM
+        "vllm": OpenAIvLLM
     }

     @classmethod
@@ -129,7 +136,7 @@ def get_default_config() -> LLMConfig:
         # Construct the LLMConfig
         return LLMConfig(
             model_id=db_model.id,
-            model_type="openai",
+            model_type="openai" if db_model.protocol == 1 else "vllm",
             model_name=db_model.base_model,
             api_key=db_model.api_key,
             api_base_url=db_model.api_domain,
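
For reference, a minimal usage sketch of the new vLLM path (not part of the commit): it calls langchain_community's VLLMOpenAI with the same keyword arguments that OpenAIvLLM._init_llm passes above. The served model name and base URL are placeholder assumptions for a locally running vLLM server.

# Hedged usage sketch, not part of the commit: exercising the new vLLM path directly.
from langchain_community.llms import VLLMOpenAI

llm = VLLMOpenAI(
    openai_api_key="Empty",                      # vLLM servers typically ignore the key
    openai_api_base="http://127.0.0.1:8000/v1",  # default suggested in supplier.ts below
    model_name="Qwen/Qwen2.5-7B-Instruct",       # hypothetical model served by vLLM
    temperature=0.6,
)
print(llm.invoke("Explain what SQLBot does in one sentence."))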

backend/apps/system/api/aimodel.py

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,7 @@ async def generate():
     try:
         additional_params = {item.key: prepare_model_arg(item.val) for item in info.config_list if item.key and item.val}
         config = LLMConfig(
-            model_type="openai",
+            model_type="openai" if info.protocol == 1 else "vllm",
             model_name=info.base_model,
             api_key=info.api_key,
             api_base_url=info.api_domain,
@@ -74,7 +74,8 @@ async def query(
         AiModelDetail.name,
         AiModelDetail.model_type,
         AiModelDetail.base_model,
-        AiModelDetail.supplier,
+        AiModelDetail.supplier,
+        AiModelDetail.protocol,
         AiModelDetail.default_model)
     if keyword is not None:
         statement = statement.where(AiModelDetail.name.like(f"%{keyword}%"))
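
The same protocol-to-model_type rule now appears in both model_factory.py and this endpoint. A tiny helper (illustrative only, not in the codebase) makes the mapping explicit:

# Illustrative only: the mapping used by both changed call sites.
def resolve_model_type(protocol: int) -> str:
    """Protocol 1 means an OpenAI-compatible chat endpoint; anything else falls back to the vLLM wrapper."""
    return "openai" if protocol == 1 else "vllm"

assert resolve_model_type(1) == "openai"
assert resolve_model_type(2) == "vllm"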

backend/apps/system/schemas/ai_model_schema.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ class AiModelItem(BaseModel):
     model_type: int
     base_model: str
     supplier: int
+    protocol: int
     default_model: bool = False

 class AiModelGridItem(AiModelItem, BaseCreatorDTO):
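
Because protocol is added without a default, payloads validated against this schema must now include it. A standalone pydantic sketch (class name illustrative, only the fields visible in this hunk) shows the effect:

# Standalone sketch with only the fields shown in the hunk above.
from pydantic import BaseModel, ValidationError

class AiModelItemSketch(BaseModel):
    model_type: int
    base_model: str
    supplier: int
    protocol: int                # 1 = OpenAI protocol, 2 = vLLM (see ModelForm.vue below)
    default_model: bool = False

try:
    AiModelItemSketch(model_type=0, base_model="gpt-4o", supplier=1)
except ValidationError as err:
    print(err)                   # protocol is now a required field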
frontend/src/assets/model/icon_vllm_colorful.png

Binary image file (4.6 KB); preview not shown.

frontend/src/entity/supplier.ts

Lines changed: 22 additions & 1 deletion
@@ -8,18 +8,25 @@ import icon_openai_colorful from '@/assets/model/icon_openai_colorful.png'
 import icon_kimi_colorful from '@/assets/model/icon_kimi_colorful.png'
 import icon_txhy_colorful from '@/assets/model/icon_txhy_colorful.png'
 import icon_hsyq_colorful from '@/assets/model/icon_hsyq_colorful.png'
+import icon_vllm_colorful from '@/assets/model/icon_vllm_colorful.png'

 type ModelArg = { key: string; val?: string | number; type: string; range?: string }
 type ModelOption = { name: string; api_domain?: string; args?: ModelArg[] }
 type ModelConfig = Record<
   number,
-  { api_domain: string; common_args?: ModelArg[]; model_options: ModelOption[] }
+  {
+    api_domain: string
+    common_args?: ModelArg[]
+    model_options: ModelOption[]
+  }
 >

 export const supplierList: Array<{
   id: number
   name: string
   icon: any
+  type?: string
+  is_private?: boolean
   model_config: ModelConfig
 }> = [
   {
@@ -259,6 +266,20 @@
       },
     },
   },
+  {
+    id: 11,
+    name: 'vLLM',
+    icon: icon_vllm_colorful,
+    type: 'vllm',
+    is_private: true,
+    model_config: {
+      0: {
+        api_domain: 'http://127.0.0.1:8000/v1',
+        common_args: [{ key: 'temperature', val: 0.6, type: 'number', range: '[0, 1]' }],
+        model_options: [],
+      },
+    },
+  },
 ]

 export const base_model_options = (supplier_id: number, model_type?: number) => {

frontend/src/views/system/model/ModelForm.vue

Lines changed: 6 additions & 4 deletions
@@ -38,6 +38,7 @@ const modelForm = reactive({
   api_key: '',
   api_domain: '',
   config_list: [],
+  protocol: 1,
 })
 const isCreate = ref(false)
 const modelRef = ref()
@@ -90,7 +91,7 @@ const handleCurrentChange = (val: any) => {
   currentPage.value = val
 }

-const rules = {
+const rules = computed(() => ({
   model_type: [
     {
       required: true,
@@ -109,12 +110,12 @@ const rules = {
   name: [{ required: true, message: t('model.the_basic_model'), trigger: 'blur' }],
   api_key: [
     {
-      required: true,
+      required: !currentSupplier.value?.is_private,
       message: t('datasource.please_enter') + t('common.empty') + 'API Key',
       trigger: 'blur',
     },
   ],
-}
+}))

 onMounted(() => {
   setTimeout(() => {
@@ -169,6 +170,7 @@ const supplierChang = (supplier: any) => {
   const config = supplier.model_config[modelForm.model_type || 0]
   modelForm.api_domain = config.api_domain
   modelForm.base_model = ''
+  modelForm.protocol = supplier.type === 'vllm' ? 2 : 1
   advancedSetting.value = []
 }
 let curId = +new Date()
@@ -201,7 +203,7 @@ const formatAdvancedSetting = (list: Array<any>) => {
   advancedSetting.value = setting_list
 }
 const baseModelChange = (val: string) => {
-  if (!val || !modelForm.supplier || !modelList.value?.length) {
+  if (!val || !modelForm.supplier) {
     return
   }
   const current_model = modelList.value?.find((model: any) => model.name == val)
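
Putting the pieces together, a hedged end-to-end sketch of what the form submits when the new vLLM supplier is selected and how the backend branches on it. Field values are illustrative, not taken from the commit.

# Hedged end-to-end sketch; field values are illustrative placeholders.
form_payload = {
    "name": "local-vllm",
    "supplier": 11,                            # new vLLM supplier id in supplier.ts
    "protocol": 2,                             # set by supplierChang when supplier.type === 'vllm'
    "base_model": "Qwen/Qwen2.5-7B-Instruct",  # hypothetical locally served model
    "api_key": "",                             # allowed: is_private suppliers relax the API Key rule
    "api_domain": "http://127.0.0.1:8000/v1",
    "config_list": [{"key": "temperature", "val": 0.6}],
}

model_type = "openai" if form_payload["protocol"] == 1 else "vllm"
print(model_type)  # "vllm" -> LLMFactory resolves OpenAIvLLM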
