Skip to content

Commit 37a2041

Browse files
committed
refactor: shared model
1 parent 2c20733 commit 37a2041

File tree

3 files changed

+65
-21
lines changed

3 files changed

+65
-21
lines changed

apps/models_provider/serializers/model_serializer.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from rest_framework import serializers
1111

1212
from common.config.embedding_config import ModelManage
13+
from common.database_model_manage.database_model_manage import DatabaseModelManage
1314
from common.exception.app_exception import AppApiException
1415
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
1516
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
@@ -394,7 +395,22 @@ def save_model_params_form(self, model_params_form, with_valid=True):
394395
return True
395396

396397

397-
class SharedModelSerializer(serializers.Serializer):
398+
def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization):
399+
white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
400+
workspace_id=workspace_id, authentication_type='WHITE_LIST'
401+
).values_list('model_id', flat=True)
402+
black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
403+
workspace_id=workspace_id, authentication_type='BLACK_LIST'
404+
).values_list('model_id', flat=True)
405+
tool_query_set = tool_query_set.filter(
406+
id__in=white_authorized_tool_ids
407+
).exclude(
408+
id__in=black_authorized_tool_ids
409+
)
410+
return tool_query_set
411+
412+
413+
class WorkspaceSharedModelSerializer(serializers.Serializer):
398414
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
399415
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
400416
model_type = serializers.CharField(required=False, label=_('model type'))
@@ -404,7 +420,10 @@ class SharedModelSerializer(serializers.Serializer):
404420

405421
def get_share_model_list(self):
406422
self.is_valid(raise_exception=True)
407-
queryset = QuerySet(Model).filter(workspace_id='None')
423+
workspace_id = self.data.get('workspace_id')
424+
425+
queryset = self._build_queryset(workspace_id)
426+
408427
return [
409428
{
410429
'id': str(model.id),
@@ -419,3 +438,23 @@ def get_share_model_list(self):
419438
}
420439
for model in queryset.order_by("-create_time")
421440
]
441+
442+
def _build_queryset(self, workspace_id):
443+
queryset = QuerySet(Model)
444+
if workspace_id:
445+
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
446+
if model_workspace_authorization is not None:
447+
queryset = get_authorized_tool(queryset, workspace_id,
448+
model_workspace_authorization=model_workspace_authorization)
449+
450+
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
451+
value = self.data.get(field)
452+
if value is not None:
453+
if field == 'name':
454+
queryset = queryset.filter(**{f'{field}__icontains': value})
455+
elif field == 'create_user':
456+
queryset = queryset.filter(user_id=value)
457+
else:
458+
queryset = queryset.filter(**{field: value})
459+
460+
return queryset

apps/models_provider/urls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download',
1919
views.ModelSetting.PauseDownload.as_view()),
2020
path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.ModelSetting.ModelMeta.as_view()),
21-
path('workspace/<str:workspace_id>/shared/model', views.SharedModel.as_view()),
21+
path('system/shared/workspace/<str:workspace_id>/model', views.WorkspaceSharedModelSetting.as_view()),
2222
]
2323

2424
if os.environ.get('SERVER_NAME', 'web') == 'local_model':

apps/models_provider/views/model.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse
2222
from models_provider.api.provide import ProvideApi
2323
from models_provider.models import Model
24-
from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer
24+
from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer, \
25+
WorkspaceSharedModelSerializer
2526
from system_manage.views import encryption_str
2627

2728

@@ -65,7 +66,7 @@ class ModelSetting(APIView):
6566
request=ModelCreateAPI.get_request(),
6667
responses=ModelCreateAPI.get_response())
6768
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
68-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
69+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
6970
@log(menu='model', operate='Create model',
7071
get_operation_object=lambda r, k: {'name': r.date.get('name')},
7172
get_details=get_edit_model_details,
@@ -95,7 +96,7 @@ def post(self, request: Request, workspace_id: str):
9596
responses=ModelListResponse.get_response(),
9697
tags=[_('Model')]) # type: ignore
9798
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
98-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
99+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
99100
def get(self, request: Request, workspace_id: str):
100101
return result.success(
101102
ModelSerializer.Query(
@@ -114,7 +115,7 @@ class Operate(APIView):
114115
responses=ModelEditApi.get_response(),
115116
tags=[_('Model')]) # type: ignore
116117
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission(),
117-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
118+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
118119
@log(menu='model', operate='Update model',
119120
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
120121
get_details=get_edit_model_details,
@@ -133,7 +134,7 @@ def put(self, request: Request, workspace_id, model_id: str):
133134
responses=DefaultModelResponse.get_response(),
134135
tags=[_('Model')]) # type: ignore
135136
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission(),
136-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
137+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
137138
@log(menu='model', operate='Delete model',
138139
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
139140
)
@@ -150,7 +151,7 @@ def delete(self, request: Request, workspace_id: str, model_id: str):
150151
responses=GetModelApi.get_response(),
151152
tags=[_('Model')]) # type: ignore
152153
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
153-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
154+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
154155
def get(self, request: Request, workspace_id: str, model_id: str):
155156
return result.success(
156157
ModelSerializer.Operate(
@@ -168,7 +169,7 @@ class ModelParamsForm(APIView):
168169
responses=ProvideApi.ModelParamsForm.get_response(),
169170
tags=[_('Model')]) # type: ignore
170171
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
171-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
172+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
172173
def get(self, request: Request, workspace_id: str, model_id: str):
173174
return result.success(
174175
ModelSerializer.ModelParams(data={'id': model_id}).get_model_params())
@@ -182,7 +183,7 @@ def get(self, request: Request, workspace_id: str, model_id: str):
182183
responses=ProvideApi.ModelParamsForm.get_response(),
183184
tags=[_('Model')]) # type: ignore
184185
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
185-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
186+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
186187
@log(menu='model', operate='Save model parameter form',
187188
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
188189
)
@@ -204,7 +205,7 @@ class ModelMeta(APIView):
204205
responses=GetModelApi.get_response(),
205206
tags=[_('Model')]) # type: ignore
206207
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
207-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
208+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
208209
def get(self, request: Request, workspace_id: str, model_id: str):
209210
return result.success(
210211
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True))
@@ -221,25 +222,29 @@ class PauseDownload(APIView):
221222
responses=DefaultModelResponse.get_response(),
222223
tags=[_('Model')]) # type: ignore
223224
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
224-
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
225+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
225226
def put(self, request: Request, workspace_id: str, model_id: str):
226227
return result.success(
227228
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download())
228229

229230

230-
class SharedModel(APIView):
231+
class WorkspaceSharedModelSetting(APIView):
231232
authentication_classes = [TokenAuth]
232233

233234
@extend_schema(
234235
methods=['Get'],
235-
summary=_('Get Share model'),
236-
description=_('Get Share model'),
237-
operation_id=_('Get Share model'), # type: ignore
238-
parameters=ModelCreateAPI.get_parameters(),
239-
responses=ModelListResponse.get_response(),
236+
summary=_('Get Share model by workspace id'),
237+
description=_('Get Share model by workspace id'),
238+
operation_id=_('Get Share model by workspace id'), # type: ignore
239+
parameters=ModelListResponse.get_parameters(),
240+
responses=DefaultModelResponse.get_response(),
240241
tags=[_('Shared Model')]
241242
) # type: ignore
242-
@has_permissions(PermissionConstants.MODEL_READ, RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
243+
@has_permissions(
244+
PermissionConstants.MODEL_READ.get_workspace_permission(),
245+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
246+
RoleConstants.USER.get_workspace_role(),
247+
)
243248
def get(self, request: Request, workspace_id: str):
244249
return result.success(
245-
SharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())
250+
WorkspaceSharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())

0 commit comments

Comments
 (0)