Skip to content

Commit 6e0e0d2

Browse files
committed
refactor: model api
1 parent 66c868e commit 6e0e0d2

File tree

3 files changed

+86
-44
lines changed

3 files changed

+86
-44
lines changed

apps/models_provider/api/model.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,52 @@ def get_data(self):
2323

2424
return ModelListResult
2525

26+
@staticmethod
27+
def get_parameters():
28+
return [OpenApiParameter(
29+
name="workspace_id",
30+
description=_("workspace id"),
31+
type=OpenApiTypes.STR,
32+
location=OpenApiParameter.PATH,
33+
required=True,
34+
),
35+
OpenApiParameter(
36+
name="name",
37+
description=_("model name"),
38+
type=OpenApiTypes.STR,
39+
location=OpenApiParameter.QUERY,
40+
required=False,
41+
),
42+
OpenApiParameter(
43+
name="model_type",
44+
description=_("model type"),
45+
type=OpenApiTypes.STR,
46+
location=OpenApiParameter.QUERY,
47+
required=False,
48+
),
49+
OpenApiParameter(
50+
name="model_name",
51+
description=_("base model"),
52+
type=OpenApiTypes.STR,
53+
location=OpenApiParameter.QUERY,
54+
required=False,
55+
),
56+
OpenApiParameter(
57+
name="provider",
58+
description=_("provider"),
59+
type=OpenApiTypes.STR,
60+
location=OpenApiParameter.QUERY,
61+
required=False,
62+
),
63+
OpenApiParameter(
64+
name="create_user",
65+
description=_("create user"),
66+
type=OpenApiTypes.STR,
67+
location=OpenApiParameter.QUERY,
68+
required=False,
69+
)
70+
]
71+
2672

2773
class ModelCreateAPI(APIMixin):
2874
@staticmethod
@@ -34,7 +80,7 @@ def get_response():
3480
return ModelCreateResponse
3581

3682
@classmethod
37-
def get_query_params_api(cls):
83+
def get_parameters(cls):
3884
return [OpenApiParameter(
3985
name="workspace_id",
4086
description=_("workspace id"),

apps/models_provider/serializers/model_serializer.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def model_to_dict(model: Model):
105105

106106
class Operate(serializers.Serializer):
107107
id = serializers.UUIDField(required=True, label=_("model id"))
108-
user_id = serializers.UUIDField(required=True, label=_("user id"))
108+
user_id = serializers.UUIDField(required=False, label=_("user id"))
109109

110110
def is_valid(self, *, raise_exception=False):
111111
super().is_valid(raise_exception=True)
@@ -114,6 +114,8 @@ def is_valid(self, *, raise_exception=False):
114114
).first()
115115
if model is None:
116116
raise AppApiException(500, _('Model does not exist'))
117+
if model.workspace_id == 'None':
118+
raise AppApiException(500, _('Shared models cannot be deleted or modified'))
117119

118120
def one(self, with_valid=False):
119121
if with_valid:
@@ -147,8 +149,6 @@ def delete(self, with_valid=True):
147149
self.is_valid(raise_exception=True)
148150
model_id = self.data.get('id')
149151
model = Model.objects.filter(id=model_id).first()
150-
if not model:
151-
raise AppApiException(500, _("Model does not exist"))
152152
# TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
153153
# if model.model_type == 'LLM':
154154
# application_count = Application.objects.filter(model_id=model_id).count()
@@ -174,35 +174,32 @@ def edit(self, instance: Dict, user_id: str, with_valid=True):
174174
self.is_valid(raise_exception=True)
175175
model = QuerySet(Model).filter(id=self.data.get('id')).first()
176176

177-
if model is None:
178-
raise AppApiException(500, _('Model does not exist'))
179-
else:
180-
credential, model_credential, provider_handler = ModelSerializer.Edit(
181-
data={**instance}).is_valid(
182-
model=model)
183-
try:
184-
model.status = Status.SUCCESS
185-
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
186-
# 校验模型认证数据
187-
provider_handler.is_valid_credential(model.model_type,
188-
instance.get("model_name"),
189-
credential,
190-
default_params,
191-
raise_exception=True)
192-
193-
except AppApiException as e:
194-
if e.code == ValidCode.model_not_fount:
195-
model.status = Status.DOWNLOAD
177+
credential, model_credential, provider_handler = ModelSerializer.Edit(
178+
data={**instance}).is_valid(
179+
model=model)
180+
try:
181+
model.status = Status.SUCCESS
182+
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
183+
# 校验模型认证数据
184+
provider_handler.is_valid_credential(model.model_type,
185+
instance.get("model_name"),
186+
credential,
187+
default_params,
188+
raise_exception=True)
189+
190+
except AppApiException as e:
191+
if e.code == ValidCode.model_not_fount:
192+
model.status = Status.DOWNLOAD
193+
else:
194+
raise e
195+
update_keys = ['credential', 'name', 'model_type', 'model_name']
196+
for update_key in update_keys:
197+
if update_key in instance and instance.get(update_key) is not None:
198+
if update_key == 'credential':
199+
model_credential_str = json.dumps(credential)
200+
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
196201
else:
197-
raise e
198-
update_keys = ['credential', 'name', 'model_type', 'model_name']
199-
for update_key in update_keys:
200-
if update_key in instance and instance.get(update_key) is not None:
201-
if update_key == 'credential':
202-
model_credential_str = json.dumps(credential)
203-
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
204-
else:
205-
model.__setattr__(update_key, instance.get(update_key))
202+
model.__setattr__(update_key, instance.get(update_key))
206203

207204
ModelManage.delete_key(str(model.id))
208205
model.save()

apps/models_provider/views/model.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class ModelSetting(APIView):
6161
description=_("Create model"),
6262
operation_id=_("Create model"), # type: ignore
6363
tags=[_("Model")], # type: ignore
64-
parameters=ModelCreateAPI.get_query_params_api(),
64+
parameters=ModelCreateAPI.get_parameters(),
6565
request=ModelCreateAPI.get_request(),
6666
responses=ModelCreateAPI.get_response())
6767
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
@@ -90,7 +90,7 @@ def post(self, request: Request, workspace_id: str):
9090
summary=_('Query model list'),
9191
description=_('Query model list'),
9292
operation_id=_('Query model list'), # type: ignore
93-
parameters=ModelCreateAPI.get_query_params_api(),
93+
parameters=ModelListResponse.get_parameters(),
9494
responses=ModelListResponse.get_response(),
9595
tags=[_('Model')]) # type: ignore
9696
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@@ -108,7 +108,7 @@ class Operate(APIView):
108108
description=_('Update model'),
109109
operation_id=_('Update model'), # type: ignore
110110
request=ModelEditApi.get_request(),
111-
parameters=GetModelApi.get_query_params_api(),
111+
parameters=GetModelApi.get_parameters(),
112112
responses=ModelEditApi.get_response(),
113113
tags=[_('Model')]) # type: ignore
114114
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission())
@@ -125,7 +125,7 @@ def put(self, request: Request, workspace_id, model_id: str):
125125
summary=_('Delete model'),
126126
description=_('Delete model'),
127127
operation_id=_('Delete model'), # type: ignore
128-
parameters=GetModelApi.get_query_params_api(),
128+
parameters=GetModelApi.get_parameters(),
129129
responses=DefaultModelResponse.get_response(),
130130
tags=[_('Model')]) # type: ignore
131131
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission())
@@ -139,7 +139,7 @@ def delete(self, request: Request, workspace_id: str, model_id: str):
139139
summary=_('Query model details'),
140140
description=_('Query model details'),
141141
operation_id=_('Query model details'), # type: ignore
142-
parameters=GetModelApi.get_query_params_api(),
142+
parameters=GetModelApi.get_parameters(),
143143
responses=GetModelApi.get_response(),
144144
tags=[_('Model')]) # type: ignore
145145
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@@ -154,7 +154,7 @@ class ModelParamsForm(APIView):
154154
summary=_('Get model parameter form'),
155155
description=_('Get model parameter form'),
156156
operation_id=_('Get model parameter form'), # type: ignore
157-
parameters=GetModelApi.get_query_params_api(),
157+
parameters=GetModelApi.get_parameters(),
158158
responses=ProvideApi.ModelParamsForm.get_response(),
159159
tags=[_('Model')]) # type: ignore
160160
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@@ -166,7 +166,7 @@ def get(self, request: Request, workspace_id: str, model_id: str):
166166
summary=_('Save model parameter form'),
167167
description=_('Save model parameter form'),
168168
operation_id=_('Save model parameter form'), # type: ignore
169-
parameters=GetModelApi.get_query_params_api(),
169+
parameters=GetModelApi.get_parameters(),
170170
request=GetModelApi.get_request(),
171171
responses=ProvideApi.ModelParamsForm.get_response(),
172172
tags=[_('Model')]) # type: ignore
@@ -187,7 +187,7 @@ class ModelMeta(APIView):
187187
'Query model meta information, this interface does not carry authentication information'),
188188
operation_id=_(
189189
'Query model meta information, this interface does not carry authentication information'),
190-
parameters=GetModelApi.get_query_params_api(),
190+
parameters=GetModelApi.get_parameters(),
191191
responses=GetModelApi.get_response(),
192192
tags=[_('Model')]) # type: ignore
193193
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@@ -202,7 +202,7 @@ class PauseDownload(APIView):
202202
summary=_('Pause model download'),
203203
description=_('Pause model download'),
204204
operation_id=_('Pause model download'), # type: ignore
205-
parameters=GetModelApi.get_query_params_api(),
205+
parameters=GetModelApi.get_parameters(),
206206
request=GetModelApi.get_request(),
207207
responses=DefaultModelResponse.get_response(),
208208
tags=[_('Model')]) # type: ignore
@@ -218,9 +218,8 @@ class Share(APIView):
218218
summary=_('Get Share model'),
219219
description=_('Get Share model'),
220220
operation_id=_('Get Share model'), # type: ignore
221-
parameters=GetModelApi.get_query_params_api(),
222-
request=GetModelApi.get_request(),
223-
responses=DefaultModelResponse.get_response(),
221+
parameters=ModelListResponse.get_parameters(),
222+
responses=ModelListResponse.get_response(),
224223
tags=[_('Model')]) # type: ignore
225224
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
226225
def get(self, request: Request, workspace_id: str):

0 commit comments

Comments
 (0)