@@ -21,11 +21,27 @@ class QianfanEmbeddingCredential(BaseForm, BaseModelCredential):
2121
2222 def is_valid (self , model_type : str , model_name , model_credential : Dict [str , object ], model_params , provider ,
2323 raise_exception = False ):
24- model_type_list = provider .get_model_type_list ()
25- if not any (list (filter (lambda mt : mt .get ('value' ) == model_type , model_type_list ))):
26- raise AppApiException (ValidCode .valid_error .value ,
27- _ ('{model_type} Model type is not supported' ).format (model_type = model_type ))
28- self .valid_form (model_credential )
24+ api_version = model_credential .get ('api_version' , 'v1' )
25+ model = provider .get_model (model_type , model_name , model_credential , ** model_params )
26+ if api_version == 'v1' :
27+ model_type_list = provider .get_model_type_list ()
28+ if not any (list (filter (lambda mt : mt .get ('value' ) == model_type , model_type_list ))):
29+ raise AppApiException (ValidCode .valid_error .value ,
30+ _ ('{model_type} Model type is not supported' ).format (model_type = model_type ))
31+ model_info = [model .lower () for model in model .client .models ()]
32+ if not model_info .__contains__ (model_name .lower ()):
33+ raise AppApiException (ValidCode .valid_error .value ,
34+ _ ('{model_name} The model does not support' ).format (model_name = model_name ))
35+ required_keys = ['qianfan_ak' , 'qianfan_sk' ]
36+ if api_version == 'v2' :
37+ required_keys = ['api_base' , 'qianfan_ak' ]
38+
39+ for key in required_keys :
40+ if key not in model_credential :
41+ if raise_exception :
42+ raise AppApiException (ValidCode .valid_error .value , _ ('{key} is required' ).format (key = key ))
43+ else :
44+ return False
2945 try :
3046 model = provider .get_model (model_type , model_name , model_credential )
3147 model .embed_query (_ ('Hello' ))
@@ -42,8 +58,25 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
4258 return True
4359
4460 def encryption_dict (self , model : Dict [str , object ]):
45- return {** model , 'qianfan_sk' : super ().encryption (model .get ('qianfan_sk' , '' ))}
61+ api_version = model .get ('api_version' , 'v1' )
62+ if api_version == 'v1' :
63+ return {** model , 'qianfan_sk' : super ().encryption (model .get ('qianfan_sk' , '' ))}
64+ else : # v2
65+ return {** model , 'qianfan_ak' : super ().encryption (model .get ('qianfan_ak' , '' ))}
4666
47- qianfan_ak = forms .PasswordInputField ('API Key' , required = True )
67+ api_version = forms .Radio ('API Version' , required = True , text_field = 'label' , value_field = 'value' ,
68+ option_list = [
69+ {'label' : 'v1' , 'value' : 'v1' },
70+ {'label' : 'v2' , 'value' : 'v2' }
71+ ],
72+ default_value = 'v1' ,
73+ provider = '' ,
74+ method = '' , )
75+
76+ # v2版本字段
77+ api_base = forms .TextInputField ("API URL" , required = True , relation_show_field_dict = {"api_version" : ["v2" ]})
4878
49- qianfan_sk = forms .PasswordInputField ("Secret Key" , required = True )
79+ # v1版本字段
80+ qianfan_ak = forms .PasswordInputField ('API Key' , required = True )
81+ qianfan_sk = forms .PasswordInputField ("Secret Key" , required = True ,
82+ relation_show_field_dict = {"api_version" : ["v1" ]})
0 commit comments