Skip to content

Commit bcfa22c

Browse files
committed
feat: support siliconCloud rerank
1 parent 8957b77 commit bcfa22c

File tree

17 files changed: 100 additions & 29 deletions

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tti import QwenTextToImageModel
2929
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
3030
from smartdoc.conf import PROJECT_DIR
31-
from django.utils.translation import gettext_lazy as _, gettext
31+
from django.utils.translation import gettext as _, gettext
3232

3333
aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
3434
aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential()

apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
1212
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
1313
from smartdoc.conf import PROJECT_DIR
14-
from django.utils.translation import gettext_lazy as _
14+
from django.utils.translation import gettext as _
1515

1616

1717
def _create_model_info(model_name, description, model_type, credential_class, model_class):

apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
1515
from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
1616
from smartdoc.conf import PROJECT_DIR
17-
from django.utils.translation import gettext_lazy as _
17+
from django.utils.translation import gettext as _
1818
deepseek_llm_model_credential = DeepSeekLLMModelCredential()
1919

2020
deepseek_chat = ModelInfo('deepseek-chat', _('Good at common conversational tasks, supports 32K contexts'), ModelTypeConst.LLM,

apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
2121
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
2222
from smartdoc.conf import PROJECT_DIR
23-
from django.utils.translation import gettext_lazy as _
23+
from django.utils.translation import gettext as _
2424

2525

2626
gemini_llm_model_credential = GeminiLLMModelCredential()

apps/setting/models_provider/impl/local_model_provider/local_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
1717
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
1818
from smartdoc.conf import PROJECT_DIR
19-
from django.utils.translation import gettext_lazy as _
19+
from django.utils.translation import gettext as _
2020

2121
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
2222
LocalEmbeddingCredential(), LocalEmbedding)

apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage
2323
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
2424
from smartdoc.conf import PROJECT_DIR
25-
from django.utils.translation import gettext_lazy as _
25+
from django.utils.translation import gettext as _
2626

2727
""
2828

apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
2020
from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel
2121
from smartdoc.conf import PROJECT_DIR
22-
from django.utils.translation import gettext_lazy as _
22+
from django.utils.translation import gettext as _
2323

2424
qwen_model_credential = OpenAILLMModelCredential()
2525
qwenvl_model_credential = QwenVLModelCredential()

apps/setting/models_provider/impl/siliconCloud_model_provider/credential/reranker.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from common.exception.app_exception import AppApiException
1616
from common.forms import BaseForm
1717
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
18-
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
1918
from setting.models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker
2019

2120

@@ -26,7 +25,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
2625
if not model_type == 'RERANKER':
2726
raise AppApiException(ValidCode.valid_error.value,
2827
_('{model_type} Model type is not supported').format(model_type=model_type))
29-
for key in ['dashscope_api_key']:
28+
for key in ['api_base', 'api_key']:
3029
if key not in model_credential:
3130
if raise_exception:
3231
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
@@ -47,6 +46,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4746
return True
4847

4948
def encryption_dict(self, model: Dict[str, object]):
50-
return {**model, 'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))}
51-
52-
dashscope_api_key = forms.PasswordInputField('API Key', required=True)
49+
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
50+
api_base = forms.TextInputField('API URL', required=True)
51+
api_key = forms.PasswordInputField('API Key', required=True)

apps/setting/models_provider/impl/siliconCloud_model_provider/model/reranker.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,92 @@
22
"""
33
@project: MaxKB
44
@Author:虎
5-
@file: reranker.py.py
6-
@date:2024/9/2 16:42
7-
@desc:
5+
@file: siliconcloud_reranker.py
6+
@date:2024/9/10 9:45
7+
@desc: SiliconCloud 文档重排封装
88
"""
9-
from typing import Dict
109

11-
from langchain_community.document_compressors import DashScopeRerank
10+
from typing import Sequence, Optional, Any, Dict
11+
import requests
12+
13+
from langchain_core.callbacks import Callbacks
14+
from langchain_core.documents import BaseDocumentCompressor, Document
1215

1316
from setting.models_provider.base_model_provider import MaxKBBaseModel
17+
from django.utils.translation import gettext as _
18+
1419

20+
class SiliconCloudReranker(MaxKBBaseModel, BaseDocumentCompressor):
21+
api_base: Optional[str]
22+
"""SiliconCloud API URL"""
23+
model: Optional[str]
24+
"""SiliconCloud 重排模型 ID"""
25+
api_key: Optional[str]
26+
"""API Key"""
27+
28+
top_n: Optional[int] = 3 # 取前 N 个最相关的结果
1529

16-
class SiliconCloudReranker(MaxKBBaseModel, DashScopeRerank):
1730
@staticmethod
1831
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
19-
return SiliconCloudReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
20-
top_n=model_kwargs.get('top_n', 3))
32+
return SiliconCloudReranker(
33+
api_base=model_credential.get('api_base'),
34+
model=model_name,
35+
api_key=model_credential.get('api_key'),
36+
top_n=model_kwargs.get('top_n', 3)
37+
)
38+
39+
def __init__(
40+
self, api_base: Optional[str] = None, model: Optional[str] = None, top_n=3,
41+
api_key: Optional[str] = None
42+
):
43+
super().__init__()
44+
45+
if not api_base:
46+
raise ValueError(_('Please provide server URL'))
47+
48+
if not model:
49+
raise ValueError(_('Please provide the model'))
50+
51+
if not api_key:
52+
raise ValueError(_('Please provide the API Key'))
53+
54+
self.api_base = api_base
55+
self.model = model
56+
self.api_key = api_key
57+
self.top_n = top_n
58+
59+
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
60+
Sequence[Document]:
61+
if not documents:
62+
return []
63+
64+
# 预处理文本
65+
texts = [doc.page_content for doc in documents]
66+
67+
# 发送请求到 SiliconCloud API
68+
headers = {
69+
"Authorization": f"Bearer {self.api_key}",
70+
"Content-Type": "application/json"
71+
}
72+
payload = {
73+
"model": self.model,
74+
"query": query,
75+
"documents": texts,
76+
"top_n": self.top_n
77+
}
78+
79+
response = requests.post(f"{self.api_base}/rerank", json=payload, headers=headers)
80+
81+
if response.status_code != 200:
82+
raise RuntimeError(f"SiliconCloud API 请求失败: {response.text}")
83+
84+
res = response.json()
85+
86+
# 解析返回结果
87+
return [
88+
Document(
89+
page_content=item.get('document', {}).get('text', ''),
90+
metadata={'relevance_score': item.get('relevance_score')}
91+
)
92+
for item in res.get('results', [])
93+
]

apps/setting/models_provider/impl/siliconCloud_model_provider/siliconCloud_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from setting.models_provider.impl.siliconCloud_model_provider.model.stt import SiliconCloudSpeechToText
2525
from setting.models_provider.impl.siliconCloud_model_provider.model.tti import SiliconCloudTextToImage
2626
from smartdoc.conf import PROJECT_DIR
27-
from django.utils.translation import gettext_lazy as _
27+
from django.utils.translation import gettext as _
2828

2929
openai_llm_model_credential = SiliconCloudLLMModelCredential()
3030
openai_stt_model_credential = SiliconCloudSTTModelCredential()

0 commit comments

Comments
 (0)