Skip to content

Commit 8cdb085

Browse files
committed
feat: add support for v2 API version in embedding models and update validation logic
1 parent 20cf018 commit 8cdb085

File tree

3 files changed

+88
-14
lines changed

3 files changed

+88
-14
lines changed

apps/models_provider/impl/openai_model_provider/model/llm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import List, Dict
1010

1111
from langchain_core.messages import BaseMessage, get_buffer_string
12-
from langchain_openai.chat_models import ChatOpenAI
1312

1413
from common.config.tokenizer_manage_config import TokenizerManage
1514
from models_provider.base_model_provider import MaxKBBaseModel

apps/models_provider/impl/wenxin_model_provider/credential/embedding.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,27 @@ class QianfanEmbeddingCredential(BaseForm, BaseModelCredential):
2121

2222
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
2323
raise_exception=False):
24-
model_type_list = provider.get_model_type_list()
25-
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
26-
raise AppApiException(ValidCode.valid_error.value,
27-
_('{model_type} Model type is not supported').format(model_type=model_type))
28-
self.valid_form(model_credential)
24+
api_version = model_credential.get('api_version', 'v1')
25+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
26+
if api_version == 'v1':
27+
model_type_list = provider.get_model_type_list()
28+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
29+
raise AppApiException(ValidCode.valid_error.value,
30+
_('{model_type} Model type is not supported').format(model_type=model_type))
31+
model_info = [model.lower() for model in model.client.models()]
32+
if not model_info.__contains__(model_name.lower()):
33+
raise AppApiException(ValidCode.valid_error.value,
34+
_('{model_name} The model does not support').format(model_name=model_name))
35+
required_keys = ['qianfan_ak', 'qianfan_sk']
36+
if api_version == 'v2':
37+
required_keys = ['api_base', 'qianfan_ak']
38+
39+
for key in required_keys:
40+
if key not in model_credential:
41+
if raise_exception:
42+
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
43+
else:
44+
return False
2945
try:
3046
model = provider.get_model(model_type, model_name, model_credential)
3147
model.embed_query(_('Hello'))
@@ -42,8 +58,25 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4258
return True
4359

4460
def encryption_dict(self, model: Dict[str, object]):
45-
return {**model, 'qianfan_sk': super().encryption(model.get('qianfan_sk', ''))}
61+
api_version = model.get('api_version', 'v1')
62+
if api_version == 'v1':
63+
return {**model, 'qianfan_sk': super().encryption(model.get('qianfan_sk', ''))}
64+
else: # v2
65+
return {**model, 'qianfan_ak': super().encryption(model.get('qianfan_ak', ''))}
4666

47-
qianfan_ak = forms.PasswordInputField('API Key', required=True)
67+
api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value',
68+
option_list=[
69+
{'label': 'v1', 'value': 'v1'},
70+
{'label': 'v2', 'value': 'v2'}
71+
],
72+
default_value='v1',
73+
provider='',
74+
method='', )
75+
76+
# v2版本字段
77+
api_base = forms.TextInputField("API URL", required=True, relation_show_field_dict={"api_version": ["v2"]})
4878

49-
qianfan_sk = forms.PasswordInputField("Secret Key", required=True)
79+
# v1版本字段
80+
qianfan_ak = forms.PasswordInputField('API Key', required=True)
81+
qianfan_sk = forms.PasswordInputField("Secret Key", required=True,
82+
relation_show_field_dict={"api_version": ["v1"]})

apps/models_provider/impl/wenxin_model_provider/model/embedding.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,60 @@
66
@date:2024/10/17 16:48
77
@desc:
88
"""
9-
from typing import Dict
10-
9+
from typing import Dict, List
1110
from langchain_community.embeddings import QianfanEmbeddingsEndpoint
12-
11+
import openai
1312
from models_provider.base_model_provider import MaxKBBaseModel
1413

1514

16-
class QianfanEmbeddings(MaxKBBaseModel, QianfanEmbeddingsEndpoint):
15+
class QianfanV1Embeddings(MaxKBBaseModel, QianfanEmbeddingsEndpoint):
1716
@staticmethod
1817
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
19-
return QianfanEmbeddings(
18+
return QianfanV1Embeddings(
2019
model=model_name,
2120
qianfan_ak=model_credential.get('qianfan_ak'),
2221
qianfan_sk=model_credential.get('qianfan_sk'),
2322
)
23+
24+
25+
class QianfanV2EmbeddingModel(MaxKBBaseModel):
26+
model_name: str
27+
28+
@staticmethod
29+
def is_cache_model():
30+
return False
31+
32+
def __init__(self, api_key, base_url, model_name: str):
33+
self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
34+
self.model_name = model_name
35+
36+
@staticmethod
37+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
38+
return QianfanV2EmbeddingModel(
39+
api_key=model_credential.get('qianfan_ak'),
40+
model_name=model_name,
41+
base_url=model_credential.get('api_base'),
42+
)
43+
44+
def embed_query(self, text: str):
45+
res = self.embed_documents([text])
46+
return res[0]
47+
48+
def embed_documents(
49+
self, texts: List[ str],
50+
) -> List[List[float]]:
51+
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
52+
return [e.embedding for e in res.data]
53+
54+
55+
class QianfanEmbeddings(MaxKBBaseModel):
56+
57+
58+
@staticmethod
59+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
60+
api_version = model_credential.get('api_version', 'v1')
61+
62+
if api_version == "v1":
63+
return QianfanV1Embeddings.new_instance(model_type, model_name, model_credential, **model_kwargs)
64+
elif api_version == "v2":
65+
return QianfanV2EmbeddingModel.new_instance(model_type, model_name, model_credential, **model_kwargs)

0 commit comments

Comments
 (0)