Skip to content

Commit 823d529

Browse files
committed
refactor: check model use model_params
1 parent 628cf70 commit 823d529

File tree

66 files changed

+116
-107
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+116
-107
lines changed

apps/setting/models_provider/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_model_type_list(provider):
8181
return get_provider(provider).get_model_type_list()
8282

8383

84-
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
84+
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params, raise_exception=False):
8585
"""
8686
校验模型认证参数
8787
@param provider: 供应商字符串
@@ -91,4 +91,4 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict
9191
@param raise_exception: 是否抛出错误
9292
@return: True|False
9393
"""
94-
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
94+
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params, raise_exception)

apps/setting/models_provider/base_model_provider.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,13 @@ def get_model_credential(self, model_type, model_name):
6767
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
6868
return model_info.model_credential
6969

70-
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
70+
def get_model_params(self, model_type, model_name):
7171
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
72-
return model_info.model_credential.is_valid(model_type, model_name, model_credential, self,
72+
return model_info.model_credential
73+
74+
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], model_params: Dict[str, object], raise_exception=False):
75+
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
76+
return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self,
7377
raise_exception=raise_exception)
7478

7579
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
@@ -105,7 +109,7 @@ def filter_optional_params(model_kwargs):
105109
class BaseModelCredential(ABC):
106110

107111
@abstractmethod
108-
def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
112+
def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider, raise_exception=True):
109113
pass
110114

111115
@abstractmethod

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
1919

20-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
20+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
2121
raise_exception=False):
2222
model_type_list = provider.get_model_type_list()
2323
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class QwenModelParams(BaseForm):
3737

3838
class QwenVLModelCredential(BaseForm, BaseModelCredential):
3939

40-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
40+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
4141
raise_exception=False):
4242
model_type_list = provider.get_model_type_list()
4343
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
@@ -49,7 +49,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4949
else:
5050
return False
5151
try:
52-
model = provider.get_model(model_type, model_name, model_credential)
52+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
5353
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
5454
for chunk in res:
5555
print(chunk)

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class BaiLianLLMModelParams(BaseForm):
2828

2929
class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
3030

31-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
31+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
3232
raise_exception=False):
3333
model_type_list = provider.get_model_type_list()
3434
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
@@ -41,7 +41,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4141
else:
4242
return False
4343
try:
44-
model = provider.get_model(model_type, model_name, model_credential)
44+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
4545
model.invoke([HumanMessage(content='你好')])
4646
except Exception as e:
4747
if isinstance(e, AppApiException):

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
2121

22-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
22+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
2323
raise_exception=False):
2424
if not model_type == 'RERANKER':
2525
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
1212
api_key = forms.PasswordInputField("API Key", required=True)
1313

14-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
14+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
1515
raise_exception=False):
1616
model_type_list = provider.get_model_type_list()
1717
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class QwenModelParams(BaseForm):
6161

6262
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
6363

64-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
64+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
6565
raise_exception=False):
6666
model_type_list = provider.get_model_type_list()
6767
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
@@ -73,7 +73,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
7373
else:
7474
return False
7575
try:
76-
model = provider.get_model(model_type, model_name, model_credential)
76+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
7777
res = model.check_auth()
7878
print(res)
7979
except Exception as e:

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class AliyunBaiLianTTSModelGeneralParams(BaseForm):
4545
class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
4646
api_key = forms.PasswordInputField("API Key", required=True)
4747

48-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
48+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
4949
raise_exception=False):
5050
model_type_list = provider.get_model_type_list()
5151
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
@@ -58,7 +58,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
5858
else:
5959
return False
6060
try:
61-
model = provider.get_model(model_type, model_name, model_credential)
61+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
6262
model.check_auth()
6363
except Exception as e:
6464
if isinstance(e, AppApiException):

apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
2626
with open(credentials_path, 'w') as file:
2727
file.write(content)
2828

29-
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
29+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
3030
raise_exception=False):
3131
model_type_list = provider.get_model_type_list()
3232
if not any(mt.get('value') == model_type for mt in model_type_list):

0 commit comments

Comments
 (0)