11""" Repository pattern, errors and data structures for models.tags
22"""
3-
4- from typing import TypedDict
5-
3+ from common_library .errors_classes import OsparcErrorMixin
64from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
5+ from typing_extensions import TypedDict
76
87from .utils_repos import pass_or_acquire_connection , transaction_context
98from .utils_tags_sql import (
9+ TagAccessRightsDict ,
1010 count_groups_with_given_access_rights_stmt ,
1111 create_tag_stmt ,
12+ delete_tag_access_rights_stmt ,
1213 delete_tag_stmt ,
1314 get_tag_stmt ,
15+ has_access_rights_stmt ,
16+ list_tag_group_access_stmt ,
1417 list_tags_stmt ,
15- set_tag_access_rights_stmt ,
1618 update_tag_stmt ,
19+ upsert_tags_access_rights_stmt ,
1720)
1821
22+ __all__ : tuple [str , ...] = ("TagAccessRightsDict" ,)
23+
1924
2025#
2126# Errors
2227#
23- class BaseTagError ( Exception ):
24- pass
28+ class _BaseTagError ( OsparcErrorMixin , Exception ):
29+ msg_template = "Tag repo error on tag {tag_id}"
2530
2631
27- class TagNotFoundError (BaseTagError ):
32+ class TagNotFoundError (_BaseTagError ):
2833 pass
2934
3035
31- class TagOperationNotAllowedError (BaseTagError ): # maps to AccessForbidden
36+ class TagOperationNotAllowedError (_BaseTagError ): # maps to AccessForbidden
3237 pass
3338
3439
@@ -108,7 +113,7 @@ async def create(
108113 assert tag # nosec
109114
110115 # take tag ownership
111- access_stmt = set_tag_access_rights_stmt (
116+ access_stmt = upsert_tags_access_rights_stmt (
112117 tag_id = tag .id ,
113118 user_id = user_id ,
114119 read = read ,
@@ -163,8 +168,7 @@ async def get(
163168 result = await conn .execute (stmt_get )
164169 row = result .first ()
165170 if not row :
166- msg = f"{ tag_id = } not found: either no access or does not exists"
167- raise TagNotFoundError (msg )
171+ raise TagNotFoundError (operation = "get" , tag_id = tag_id , user_id = user_id )
168172 return TagDict (
169173 id = row .id ,
170174 name = row .name ,
@@ -198,8 +202,9 @@ async def update(
198202 result = await conn .execute (update_stmt )
199203 row = result .first ()
200204 if not row :
201- msg = f"{ tag_id = } not updated: either no access or not found"
202- raise TagOperationNotAllowedError (msg )
205+ raise TagOperationNotAllowedError (
206+ operation = "update" , tag_id = tag_id , user_id = user_id
207+ )
203208
204209 return TagDict (
205210 id = row .id ,
@@ -222,44 +227,95 @@ async def delete(
222227 async with transaction_context (self .engine , connection ) as conn :
223228 deleted = await conn .scalar (stmt_delete )
224229 if not deleted :
225- msg = f"Could not delete { tag_id = } . Not found or insuficient access."
226- raise TagOperationNotAllowedError (msg )
230+ raise TagOperationNotAllowedError (
231+ operation = "delete" , tag_id = tag_id , user_id = user_id
232+ )
227233
228234 #
229235 # ACCESS RIGHTS
230236 #
231237
232- async def create_access_rights (
238+ async def has_access_rights (
233239 self ,
234240 connection : AsyncConnection | None = None ,
235241 * ,
236242 user_id : int ,
237243 tag_id : int ,
238- group_id : int ,
239- read : bool ,
240- write : bool ,
241- delete : bool ,
242- ):
243- raise NotImplementedError
244+ read : bool = False ,
245+ write : bool = False ,
246+ delete : bool = False ,
247+ ) -> bool :
248+ async with pass_or_acquire_connection (self .engine , connection ) as conn :
249+ group_id_or_none = await conn .scalar (
250+ has_access_rights_stmt (
251+ tag_id = tag_id ,
252+ caller_user_id = user_id ,
253+ read = read ,
254+ write = write ,
255+ delete = delete ,
256+ )
257+ )
258+ return bool (group_id_or_none )
244259
245- async def update_access_rights (
260+ async def list_access_rights (
261+ self ,
262+ connection : AsyncConnection | None = None ,
263+ * ,
264+ tag_id : int ,
265+ ) -> list [TagAccessRightsDict ]:
266+ async with pass_or_acquire_connection (self .engine , connection ) as conn :
267+ result = await conn .execute (list_tag_group_access_stmt (tag_id = tag_id ))
268+ return [
269+ TagAccessRightsDict (
270+ tag_id = row .tag_id ,
271+ group_id = row .group_id ,
272+ read = row .read ,
273+ write = row .write ,
274+ delete = row .delete ,
275+ )
276+ for row in result .fetchall ()
277+ ]
278+
279+ async def create_or_update_access_rights (
246280 self ,
247281 connection : AsyncConnection | None = None ,
248282 * ,
249- user_id : int ,
250283 tag_id : int ,
251284 group_id : int ,
252285 read : bool ,
253286 write : bool ,
254287 delete : bool ,
255- ):
256- raise NotImplementedError
288+ ) -> TagAccessRightsDict :
289+ async with transaction_context (self .engine , connection ) as conn :
290+ result = await conn .execute (
291+ upsert_tags_access_rights_stmt (
292+ tag_id = tag_id ,
293+ group_id = group_id ,
294+ read = read ,
295+ write = write ,
296+ delete = delete ,
297+ )
298+ )
299+ row = result .first ()
300+ assert row is not None
301+
302+ return TagAccessRightsDict (
303+ tag_id = row .tag_id ,
304+ group_id = row .group_id ,
305+ read = row .read ,
306+ write = row .write ,
307+ delete = row .delete ,
308+ )
257309
258310 async def delete_access_rights (
259311 self ,
260312 connection : AsyncConnection | None = None ,
261313 * ,
262- user_id : int ,
263314 tag_id : int ,
264- ):
265- raise NotImplementedError
315+ group_id : int ,
316+ ) -> bool :
317+ async with transaction_context (self .engine , connection ) as conn :
318+ deleted : bool = await conn .scalar (
319+ delete_tag_access_rights_stmt (tag_id = tag_id , group_id = group_id )
320+ )
321+ return deleted
0 commit comments