Skip to content

Commit 8fe1a14

Browse files
authored
feat: 支持讯飞星火大模型 #181 (#184)
1 parent 4aa1b58 commit 8fe1a14

File tree

6 files changed

+158
-0
lines changed

6 files changed

+158
-0
lines changed

apps/setting/models_provider/constants/model_provider_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
1515
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
1616
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
17+
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
1718
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
1819

1920

@@ -25,3 +26,4 @@ class ModelProvideConstants(Enum):
2526
model_kimi_provider = KimiModelProvider()
2627
model_qwen_provider = QwenModelProvider()
2728
model_zhipu_provider = ZhiPuModelProvider()
29+
model_xf_provider = XunFeiModelProvider()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2024/04/19 15:55
7+
@desc:
8+
"""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
<svg t="1713509569091" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="4361" xmlns:xlink="http://www.w3.org/1999/xlink" width="100%" height="100%" ><path d="M500.1216 971.40736c-55.58272-4.28032-102.11328-16.7936-147.74272-39.7312-115.0976-57.83552-192.88064-168.5504-208.81408-297.12384-2.8672-23.18336-2.90816-69.18144-0.06144-93.3888a387.95264 387.95264 0 0 1 82.65728-196.46464c6.22592-7.7824 47.616-49.7664 94.49472-95.92832 112.8448-111.104 111.53408-109.85472 113.29536-108.09344 1.024 1.024 0.75776 5.59104-0.8192 14.45888-3.13344 17.57184-3.13344 55.99232 0.04096 77.0048 9.66656 64.49152 37.66272 124.3136 87.16288 186.32704 12.6976 15.91296 59.22816 63.24224 76.57472 77.88544 20.13184 16.9984 33.1776 37.53984 39.34208 61.8496 4.07552 16.0768 4.07552 42.10688 0 58.20416-10.24 40.57088-40.8576 72.58112-81.6128 85.38112-9.35936 2.92864-13.84448 3.39968-32.68608 3.39968s-23.3472-0.47104-32.768-3.39968c-29.02016-9.07264-56.40192-30.06464-32.52224-24.94464 12.94336 2.7648 29.65504-3.2768 37.49888-13.57824 10.81344-14.1312 12.57472-29.53216 5.09952-44.48256-3.76832-7.53664-6.8608-10.91584-19.12832-20.82816-33.1776-26.86976-65.7408-59.5968-88.8832-89.25184-11.81696-15.17568-28.8768-40.59136-33.95584-50.5856-1.92512-3.7888-4.15744-6.90176-4.95616-6.90176-2.00704 0-17.92 24.43264-24.73984 37.96992-7.84384 15.52384-15.33952 37.888-19.12832 57.0368-4.64896 23.3472-4.64896 59.14624-0.04096 82.16576 17.77664 88.63744 82.78016 154.74688 171.02848 173.8752 12.3904 2.70336 19.39456 3.23584 42.496 3.23584 23.08096 0 30.1056-0.53248 42.47552-3.23584 43.04896-9.3184 78.4384-28.672 109.6704-59.904 32.72704-32.72704 52.26496-69.44768 61.29664-115.3024 4.54656-23.06048 4.56704-56.36096 0.04096-79.29856-3.87072-19.5584-8.76544-35.16416-15.85152-50.3808-6.69696-14.35648-6.0416-15.21664 9.1136-12.0832 40.89856 8.45824 85.6064 31.41632 114.40128 58.75712 34.6112 32.84992 49.27488 65.45408 49.27488 109.71136 0 24.00256-3.4816 41.6768-13.35296 68.17792-20.54144 54.9888-50.54464 100.61824-93.7984 142.66368-51.26144 49.80736-116.8384 85.03296-183.95136 98.79552-30.45376 6.2464-76.53376 9.89184-101.1712 7.9872z m391.53664-433.93024c-17.05984-32.93184-41.75872-56.48384-76.8-73.19552-18.80064-8.97024-35.67616-14.52032-68.75136-22.58944-44.46208-10.8544-66.2528-18.2272-93.16352-31.62112-26.2144-13.04576-46.16192-27.3408-66.3552-47.5136-26.70592-26.74688-42.63936-52.4288-54.35392-87.67488-10.36288-31.27296-10.0352-27.2384-10.62912-128.96256-0.45056-78.37696-0.2048-92.0576 1.51552-92.0576 1.1264 0 45.8752 42.98752 99.40992 95.51872 190.95552 187.37152 194.58048 191.11936 216.7808 224.78848 20.13184 30.53568 39.26016 72.0896 48.76288 105.92256 4.7104 16.73216 11.0592 48.18944 11.81696 58.40896 0.7168 9.8304-2.84672 9.35936-8.23296-1.024z" fill="#3DC8F9" p-id="4362"></path><path d="M523.12064 53.8624c-1.7408 0-1.96608 13.68064-1.51552 92.0576 0.57344 101.74464 0.24576 97.6896 10.6496 128.96256 11.6736 35.2256 27.62752 60.928 54.33344 87.6544 20.19328 20.19328 40.1408 34.48832 66.3552 47.5136 26.91072 13.4144 48.70144 20.80768 93.14304 31.6416 33.09568 8.0896 49.9712 13.6192 68.75136 22.58944 35.04128 16.71168 59.74016 40.2432 76.8 73.19552 5.40672 10.38336 8.97024 10.8544 8.25344 1.024-0.75776-10.21952-7.12704-41.6768-11.81696-58.40896-9.50272-33.83296-28.63104-75.3664-48.76288-105.92256-22.20032-33.66912-25.82528-37.41696-216.7808-224.78848-53.5552-52.5312-98.304-95.51872-99.40992-95.51872z" fill="#EA0100" p-id="4363"></path><path d="M391.3728 762.30656s86.2208 88.41216 241.9712 100.06528c155.7504 11.63264 193.536-45.44512 193.536-45.44512s76.3904-100.864 65.08544-177.54112c-11.32544-76.67712-71.02464-131.21536-174.96064-154.89024 0 0 31.90784 80.2816 20.5824 128.14336-11.32544 47.86176-20.0704 138.93632-159.5392 186.7776 0 0-102.68672 30.208-186.6752-37.10976z" fill="#1652D8" p-id="4364"></path></svg>
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2024/04/19 15:55
7+
@desc:
8+
"""
9+
10+
from typing import List, Optional, Any, Iterator
11+
12+
from langchain_community.chat_models import ChatSparkLLM
13+
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
14+
from langchain_core.callbacks import CallbackManagerForLLMRun
15+
from langchain_core.messages import BaseMessage, AIMessageChunk
16+
from langchain_core.outputs import ChatGenerationChunk
17+
18+
19+
class XFChatSparkLLM(ChatSparkLLM):
20+
def _stream(
21+
self,
22+
messages: List[BaseMessage],
23+
stop: Optional[List[str]] = None,
24+
run_manager: Optional[CallbackManagerForLLMRun] = None,
25+
**kwargs: Any,
26+
) -> Iterator[ChatGenerationChunk]:
27+
default_chunk_class = AIMessageChunk
28+
29+
self.client.arun(
30+
[_convert_message_to_dict(m) for m in messages],
31+
self.spark_user_id,
32+
self.model_kwargs,
33+
True,
34+
)
35+
for content in self.client.subscribe(timeout=self.request_timeout):
36+
if "data" not in content:
37+
continue
38+
delta = content["data"]
39+
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
40+
cg_chunk = ChatGenerationChunk(message=chunk)
41+
if run_manager:
42+
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
43+
yield cg_chunk
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: xf_model_provider.py
6+
@date:2024/04/19 14:47
7+
@desc:
8+
"""
9+
import os
10+
from typing import Dict
11+
12+
from langchain.schema import HumanMessage
13+
from langchain_community.chat_models import ChatSparkLLM
14+
15+
from common import forms
16+
from common.exception.app_exception import AppApiException
17+
from common.forms import BaseForm
18+
from common.util.file_util import get_file_content
19+
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
20+
ModelInfo, IModelProvider, ValidCode
21+
from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM
22+
from smartdoc.conf import PROJECT_DIR
23+
import ssl
24+
25+
ssl._create_default_https_context = ssl.create_default_context()
26+
27+
28+
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
29+
30+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
31+
model_type_list = XunFeiModelProvider().get_model_type_list()
32+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
33+
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
34+
35+
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
36+
if key not in model_credential:
37+
if raise_exception:
38+
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
39+
else:
40+
return False
41+
try:
42+
model = XunFeiModelProvider().get_model(model_type, model_name, model_credential)
43+
model.invoke([HumanMessage(content='你好')])
44+
except Exception as e:
45+
if isinstance(e, AppApiException):
46+
raise e
47+
if raise_exception:
48+
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
49+
else:
50+
return False
51+
return True
52+
53+
def encryption_dict(self, model: Dict[str, object]):
54+
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
55+
56+
spark_api_url = forms.TextInputField('API 域名', required=True)
57+
spark_app_id = forms.TextInputField('APP ID', required=True)
58+
spark_api_key = forms.PasswordInputField("API Key", required=True)
59+
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
60+
61+
62+
qwen_model_credential = XunFeiLLMModelCredential()
63+
64+
model_dict = {
65+
'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential),
66+
'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential),
67+
'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential)
68+
}
69+
70+
71+
class XunFeiModelProvider(IModelProvider):
72+
73+
def get_dialogue_number(self):
74+
return 3
75+
76+
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM:
77+
zhipuai_chat = XFChatSparkLLM(
78+
spark_app_id=model_credential.get('spark_app_id'),
79+
spark_api_key=model_credential.get('spark_api_key'),
80+
spark_api_secret=model_credential.get('spark_api_secret'),
81+
spark_api_url=model_credential.get('spark_api_url'),
82+
spark_llm_domain=model_name
83+
)
84+
return zhipuai_chat
85+
86+
def get_model_credential(self, model_type, model_name):
87+
if model_name in model_dict:
88+
return model_dict.get(model_name).model_credential
89+
return qwen_model_credential
90+
91+
def get_model_provide_info(self):
92+
return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content(
93+
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon',
94+
'xf_icon_svg')))
95+
96+
def get_model_list(self, model_type: str):
97+
if model_type is None:
98+
raise AppApiException(500, '模型类型不能为空')
99+
return [model_dict.get(key).to_dict() for key in
100+
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
101+
102+
def get_model_type_list(self):
103+
return [{'key': "大语言模型", 'value': "LLM"}]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dashscope = "^1.17.0"
3636
zhipuai = "^2.0.1"
3737
httpx = "^0.27.0"
3838
httpx-sse = "^0.4.0"
39+
websocket-client = "^1.7.0"
3940

4041
[build-system]
4142
requires = ["poetry-core"]

0 commit comments

Comments
 (0)