11""" Repository pattern, errors and data structures for models.tags 
22""" 
3- 
4- from  typing  import  TypedDict 
5- 
3+ from  common_library .errors_classes  import  OsparcErrorMixin 
64from  sqlalchemy .ext .asyncio  import  AsyncConnection , AsyncEngine 
5+ from  typing_extensions  import  TypedDict 
76
87from  .utils_repos  import  pass_or_acquire_connection , transaction_context 
98from  .utils_tags_sql  import  (
9+     TagAccessRightsDict ,
1010    count_groups_with_given_access_rights_stmt ,
1111    create_tag_stmt ,
12+     delete_tag_access_rights_stmt ,
1213    delete_tag_stmt ,
1314    get_tag_stmt ,
15+     has_access_rights_stmt ,
16+     list_tag_group_access_stmt ,
1417    list_tags_stmt ,
15-     set_tag_access_rights_stmt ,
1618    update_tag_stmt ,
19+     upsert_tags_access_rights_stmt ,
1720)
1821
22+ __all__ : tuple [str , ...] =  ("TagAccessRightsDict" ,)
23+ 
1924
2025# 
2126# Errors 
2227# 
23- class  BaseTagError ( Exception ):
24-     pass 
28+ class  _BaseTagError ( OsparcErrorMixin ,  Exception ):
29+     msg_template   =   "Tag repo error on tag {tag_id}" 
2530
2631
27- class  TagNotFoundError (BaseTagError ):
32+ class  TagNotFoundError (_BaseTagError ):
2833    pass 
2934
3035
31- class  TagOperationNotAllowedError (BaseTagError ):  # maps to AccessForbidden 
36+ class  TagOperationNotAllowedError (_BaseTagError ):  # maps to AccessForbidden 
3237    pass 
3338
3439
@@ -108,7 +113,7 @@ async def create(
108113            assert  tag   # nosec 
109114
110115            # take tag ownership 
111-             access_stmt  =  set_tag_access_rights_stmt (
116+             access_stmt  =  upsert_tags_access_rights_stmt (
112117                tag_id = tag .id ,
113118                user_id = user_id ,
114119                read = read ,
@@ -163,8 +168,7 @@ async def get(
163168            result  =  await  conn .execute (stmt_get )
164169            row  =  result .first ()
165170            if  not  row :
166-                 msg  =  f"{ tag_id = }   not found: either no access or does not exists" 
167-                 raise  TagNotFoundError (msg )
171+                 raise  TagNotFoundError (operation = "get" , tag_id = tag_id , user_id = user_id )
168172            return  TagDict (
169173                id = row .id ,
170174                name = row .name ,
@@ -198,8 +202,9 @@ async def update(
198202            result  =  await  conn .execute (update_stmt )
199203            row  =  result .first ()
200204            if  not  row :
201-                 msg  =  f"{ tag_id = }   not updated: either no access or not found" 
202-                 raise  TagOperationNotAllowedError (msg )
205+                 raise  TagOperationNotAllowedError (
206+                     operation = "update" , tag_id = tag_id , user_id = user_id 
207+                 )
203208
204209            return  TagDict (
205210                id = row .id ,
@@ -222,44 +227,95 @@ async def delete(
222227        async  with  transaction_context (self .engine , connection ) as  conn :
223228            deleted  =  await  conn .scalar (stmt_delete )
224229            if  not  deleted :
225-                 msg  =  f"Could not delete { tag_id = }  . Not found or insuficient access." 
226-                 raise  TagOperationNotAllowedError (msg )
230+                 raise  TagOperationNotAllowedError (
231+                     operation = "delete" , tag_id = tag_id , user_id = user_id 
232+                 )
227233
228234    # 
229235    # ACCESS RIGHTS 
230236    # 
231237
232-     async  def  create_access_rights (
238+     async  def  has_access_rights (
233239        self ,
234240        connection : AsyncConnection  |  None  =  None ,
235241        * ,
236242        user_id : int ,
237243        tag_id : int ,
238-         group_id : int ,
239-         read : bool ,
240-         write : bool ,
241-         delete : bool ,
242-     ):
243-         raise  NotImplementedError 
244+         read : bool  =  False ,
245+         write : bool  =  False ,
246+         delete : bool  =  False ,
247+     ) ->  bool :
248+         async  with  pass_or_acquire_connection (self .engine , connection ) as  conn :
249+             group_id_or_none  =  await  conn .scalar (
250+                 has_access_rights_stmt (
251+                     tag_id = tag_id ,
252+                     caller_user_id = user_id ,
253+                     read = read ,
254+                     write = write ,
255+                     delete = delete ,
256+                 )
257+             )
258+             return  bool (group_id_or_none )
244259
245-     async  def  update_access_rights (
260+     async  def  list_access_rights (
261+         self ,
262+         connection : AsyncConnection  |  None  =  None ,
263+         * ,
264+         tag_id : int ,
265+     ) ->  list [TagAccessRightsDict ]:
266+         async  with  pass_or_acquire_connection (self .engine , connection ) as  conn :
267+             result  =  await  conn .execute (list_tag_group_access_stmt (tag_id = tag_id ))
268+             return  [
269+                 TagAccessRightsDict (
270+                     tag_id = row .tag_id ,
271+                     group_id = row .group_id ,
272+                     read = row .read ,
273+                     write = row .write ,
274+                     delete = row .delete ,
275+                 )
276+                 for  row  in  result .fetchall ()
277+             ]
278+ 
279+     async  def  create_or_update_access_rights (
246280        self ,
247281        connection : AsyncConnection  |  None  =  None ,
248282        * ,
249-         user_id : int ,
250283        tag_id : int ,
251284        group_id : int ,
252285        read : bool ,
253286        write : bool ,
254287        delete : bool ,
255-     ):
256-         raise  NotImplementedError 
288+     ) ->  TagAccessRightsDict :
289+         async  with  transaction_context (self .engine , connection ) as  conn :
290+             result  =  await  conn .execute (
291+                 upsert_tags_access_rights_stmt (
292+                     tag_id = tag_id ,
293+                     group_id = group_id ,
294+                     read = read ,
295+                     write = write ,
296+                     delete = delete ,
297+                 )
298+             )
299+             row  =  result .first ()
300+             assert  row  is  not   None 
301+ 
302+             return  TagAccessRightsDict (
303+                 tag_id = row .tag_id ,
304+                 group_id = row .group_id ,
305+                 read = row .read ,
306+                 write = row .write ,
307+                 delete = row .delete ,
308+             )
257309
258310    async  def  delete_access_rights (
259311        self ,
260312        connection : AsyncConnection  |  None  =  None ,
261313        * ,
262-         user_id : int ,
263314        tag_id : int ,
264-     ):
265-         raise  NotImplementedError 
315+         group_id : int ,
316+     ) ->  bool :
317+         async  with  transaction_context (self .engine , connection ) as  conn :
318+             deleted : bool  =  await  conn .scalar (
319+                 delete_tag_access_rights_stmt (tag_id = tag_id , group_id = group_id )
320+             )
321+             return  deleted 
0 commit comments