Skip to content

Commit 044465f

Browse files
committed
feat: enhance model credential validation and support for multiple API versions
1 parent 795db14 commit 044465f

File tree

5 files changed

+98
-28
lines changed

apps/application/flow/tools.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,20 @@ def get_reasoning_content(self, chunk):
6060
if not self.reasoning_content_is_end:
6161
self.reasoning_content_is_end = True
6262
self.content += self.all_content
63-
return {'content': self.all_content, 'reasoning_content': ''}
63+
return {'content': self.all_content,
64+
'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
65+
'') if chunk.additional_kwargs else ''
66+
}
6467
else:
6568
if self.reasoning_content_is_start:
6669
self.reasoning_content_chunk += chunk.content
6770
reasoning_content_end_tag_prefix_index = self.reasoning_content_chunk.find(
6871
self.reasoning_content_end_tag_prefix)
6972
if self.reasoning_content_is_end:
7073
self.content += chunk.content
71-
return {'content': chunk.content, 'reasoning_content': ''}
74+
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
75+
'') if chunk.additional_kwargs else ''
76+
}
7277
# 是否包含结束
7378
if reasoning_content_end_tag_prefix_index > -1:
7479
if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
@@ -93,7 +98,9 @@ def get_reasoning_content(self, chunk):
9398
else:
9499
if self.reasoning_content_is_end:
95100
self.content += chunk.content
96-
return {'content': chunk.content, 'reasoning_content': ''}
101+
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
102+
'') if chunk.additional_kwargs else ''
103+
}
97104
else:
98105
# aaa
99106
result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}

apps/common/forms/radio_button_field.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from common.forms.base_field import BaseExecField, TriggerType
1212

1313

14-
class Radio(BaseExecField):
14+
class RadioButton(BaseExecField):
1515
"""
1616
下拉单选
1717
"""

apps/common/forms/radio_card_field.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from common.forms.base_field import BaseExecField, TriggerType
1212

1313

14-
class Radio(BaseExecField):
14+
class RadioCard(BaseExecField):
1515
"""
1616
下拉单选
1717
"""

apps/models_provider/impl/wenxin_model_provider/credential/llm.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,23 @@ class WenxinLLMModelParams(BaseForm):
4040
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
4141
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
4242
raise_exception=False):
43-
model_type_list = provider.get_model_type_list()
44-
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
45-
raise AppApiException(ValidCode.valid_error.value,
46-
gettext('{model_type} Model type is not supported').format(model_type=model_type))
43+
# 根据api_version检查必需字段
44+
api_version = model_credential.get('api_version', 'v1')
4745
model = provider.get_model(model_type, model_name, model_credential, **model_params)
48-
model_info = [model.lower() for model in model.client.models()]
49-
if not model_info.__contains__(model_name.lower()):
50-
raise AppApiException(ValidCode.valid_error.value,
51-
gettext('{model_name} The model does not support').format(model_name=model_name))
52-
for key in ['api_key', 'secret_key']:
46+
if api_version == 'v1':
47+
model_type_list = provider.get_model_type_list()
48+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
49+
raise AppApiException(ValidCode.valid_error.value,
50+
gettext('{model_type} Model type is not supported').format(model_type=model_type))
51+
model_info = [model.lower() for model in model.client.models()]
52+
if not model_info.__contains__(model_name.lower()):
53+
raise AppApiException(ValidCode.valid_error.value,
54+
gettext('{model_name} The model does not support').format(model_name=model_name))
55+
required_keys = ['api_key', 'secret_key']
56+
if api_version == 'v2':
57+
required_keys = ['api_base', 'api_key']
58+
59+
for key in required_keys:
5360
if key not in model_credential:
5461
if raise_exception:
5562
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
@@ -64,19 +71,47 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
6471
return True
6572

6673
def encryption_dict(self, model_info: Dict[str, object]):
    """Return *model_info* with its sensitive field encrypted.

    Which field is sensitive depends on the API version recorded in the
    credential: v1 (AK/SK auth) stores a ``secret_key``, while v2
    (OpenAI-compatible auth) stores only an ``api_key``.
    """
    # Default to the legacy v1 scheme when no version is recorded.
    if model_info.get('api_version', 'v1') == 'v1':
        encrypted = super().encryption(model_info.get('secret_key', ''))
        return {**model_info, 'secret_key': encrypted}
    # v2: the api_key is the only secret to protect.
    encrypted = super().encryption(model_info.get('api_key', ''))
    return {**model_info, 'api_key': encrypted}
6880

6981
def build_model(self, model_info: Dict[str, object]):
    """Validate the credential dict for the selected API version and bind
    the relevant fields onto this instance.

    v1 (AK/SK) requires ``api_key`` + ``secret_key``; v2 (OpenAI-compatible)
    requires ``api_base`` + ``api_key``. Both require ``api_version`` and
    ``model``.

    Raises:
        AppApiException: (code 500) when any required key is missing.

    Returns:
        self, for fluent use by callers.
    """
    api_version = model_info.get('api_version', 'v1')
    # The original duplicated the missing-key loop in both branches;
    # compute the required-key list once and validate in one place.
    if api_version == 'v1':
        required_keys = ['api_version', 'api_key', 'secret_key', 'model']
    else:  # v2
        required_keys = ['api_version', 'api_base', 'api_key', 'model']
    for key in required_keys:
        if key not in model_info:
            raise AppApiException(500, gettext('{key} is required').format(key=key))
    if api_version == 'v1':
        self.api_key = model_info.get('api_key')
        self.secret_key = model_info.get('secret_key')
    else:  # v2
        self.api_base = model_info.get('api_base')
        self.api_key = model_info.get('api_key')
    return self
7697

77-
api_key = forms.PasswordInputField('API Key', required=True)
98+
# 动态字段定义 - 根据api_version显示不同字段
99+
api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value',
100+
option_list=[
101+
{'label': 'v1', 'value': 'v1'},
102+
{'label': 'v2', 'value': 'v2'}
103+
],
104+
default_value='v1',
105+
provider='',
106+
method='', )
107+
108+
# v2版本字段
109+
api_base = forms.TextInputField("API Base", required=False, relation_show_field_dict={"api_version": ["v2"]})
78110

79-
secret_key = forms.PasswordInputField("Secret Key", required=True)
111+
# v1版本字段
112+
api_key = forms.PasswordInputField('API Key', required=False)
113+
secret_key = forms.PasswordInputField("Secret Key", required=False,
114+
relation_show_field_dict={"api_version": ["v1"]})
80115

81116
def get_model_params_setting_form(self, model_name):
82117
return WenxinLLMModelParams()

apps/models_provider/impl/wenxin_model_provider/model/llm.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,22 @@
1717
from langchain_core.outputs import ChatGenerationChunk
1818

1919
from models_provider.base_model_provider import MaxKBBaseModel
20+
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
2021

2122

22-
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
23+
class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
2324
@staticmethod
2425
def is_cache_model():
2526
return False
2627

2728
@staticmethod
2829
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
2930
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
30-
return QianfanChatModel(model=model_name,
31-
qianfan_ak=model_credential.get('api_key'),
32-
qianfan_sk=model_credential.get('secret_key'),
33-
streaming=model_kwargs.get('streaming', False),
34-
init_kwargs=optional_params)
31+
return QianfanChatModelQianfan(model=model_name,
32+
qianfan_ak=model_credential.get('api_key'),
33+
qianfan_sk=model_credential.get('secret_key'),
34+
streaming=model_kwargs.get('streaming', False),
35+
init_kwargs=optional_params)
3536

3637
usage_metadata: dict = {}
3738

@@ -74,3 +75,30 @@ def _stream(
7475
if run_manager:
7576
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
7677
yield chunk
78+
79+
80+
class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
    """Qianfan chat model driven through the OpenAI-compatible (v2) endpoint."""

    @staticmethod
    def is_cache_model():
        # Instances are constructed fresh per request rather than cached.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a client from a v2 credential (``api_base`` + ``api_key``)."""
        extra = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModelOpenai(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            extra_body=extra,
        )
94+
95+
96+
class QianfanChatModel(MaxKBBaseModel):
    """Factory that dispatches to the concrete Qianfan client implementation
    based on the credential's ``api_version`` (v1 = native AK/SK endpoint,
    v2 = OpenAI-compatible endpoint)."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Default to the legacy v1 endpoint when the credential predates
        # the api_version field.
        api_version = model_credential.get('api_version', 'v1')
        if api_version == 'v1':
            return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
        if api_version == 'v2':
            return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
        # Previously an unknown version fell through the if/elif chain and
        # implicitly returned None, surfacing later as an opaque
        # AttributeError far from the cause; fail loudly at the source.
        raise ValueError(f"Unsupported Qianfan api_version: {api_version!r}")

0 commit comments

Comments
 (0)