22"""
33
44import itertools
5- from dataclasses import dataclass
65from typing import TypedDict
76
8- from aiopg . sa . connection import SAConnection
7+ from sqlalchemy . ext . asyncio import AsyncConnection , AsyncEngine
98
9+ from .base_repo import get_or_create_connection , transaction_context
1010from .utils_tags_sql import (
1111 count_users_with_access_rights_stmt ,
1212 create_tag_stmt ,
@@ -49,15 +49,16 @@ class TagDict(TypedDict, total=True):
4949 delete : bool
5050
5151
52- @dataclass (frozen = True )
5352class TagsRepo :
54- user_id : int # Determines access-rights
53+ def __init__ (self , engine : AsyncEngine ):
54+ self .engine = engine
5555
5656 async def access_count (
5757 self ,
58- conn : SAConnection ,
59- tag_id : int ,
58+ connection : AsyncConnection | None = None ,
6059 * ,
60+ user_id : int ,
61+ tag_id : int ,
6162 read : bool | None = None ,
6263 write : bool | None = None ,
6364 delete : bool | None = None ,
@@ -66,20 +67,22 @@ async def access_count(
6667 Returns 0 if tag does not match access
6768 Returns >0 if it does and represents the number of groups granting this access to the user
6869 """
69- count_stmt = count_users_with_access_rights_stmt (
70- user_id = self .user_id , tag_id = tag_id , read = read , write = write , delete = delete
71- )
72- permissions_count : int | None = await conn .scalar (count_stmt )
73- return permissions_count if permissions_count else 0
70+ async with get_or_create_connection (self .engine , connection ) as conn :
71+ count_stmt = count_users_with_access_rights_stmt (
72+ user_id = user_id , tag_id = tag_id , read = read , write = write , delete = delete
73+ )
74+ permissions_count : int | None = await conn .scalar (count_stmt )
75+ return permissions_count if permissions_count else 0
7476
7577 #
7678 # CRUD operations
7779 #
7880
7981 async def create (
8082 self ,
81- conn : SAConnection ,
83+ connection : AsyncConnection | None = None ,
8284 * ,
85+ user_id : int ,
8386 name : str ,
8487 color : str ,
8588 description : str | None = None , # =nullable
@@ -94,69 +97,91 @@ async def create(
9497 if description :
9598 values ["description" ] = description
9699
97- async with conn . begin () :
100+ async with transaction_context ( self . engine , connection ) as conn :
98101 # insert new tag
99102 insert_stmt = create_tag_stmt (** values )
100103 result = await conn .execute (insert_stmt )
101- tag = await result .first ()
104+ tag = result .first ()
102105 assert tag # nosec
103106
104107 # take tag ownership
105108 access_stmt = set_tag_access_rights_stmt (
106109 tag_id = tag .id ,
107- user_id = self . user_id ,
110+ user_id = user_id ,
108111 read = read ,
109112 write = write ,
110113 delete = delete ,
111114 )
112115 result = await conn .execute (access_stmt )
113- access = await result .first ()
116+ access = result .first ()
114117 assert access
115118
116119 return TagDict (itertools .chain (tag .items (), access .items ())) # type: ignore
117120
118- async def list_all (self , conn : SAConnection ) -> list [TagDict ]:
119- stmt_list = list_tags_stmt (user_id = self .user_id )
120- return [TagDict (row .items ()) async for row in conn .execute (stmt_list )] # type: ignore
121+ async def list_all (
122+ self ,
123+ connection : AsyncConnection | None = None ,
124+ * ,
125+ user_id : int ,
126+ ) -> list [TagDict ]:
127+ async with get_or_create_connection (self .engine , connection ) as conn :
128+ stmt_list = list_tags_stmt (user_id = user_id )
129+ return [TagDict (row .items ()) async for row in conn .execute (stmt_list )] # type: ignore
121130
122- async def get (self , conn : SAConnection , tag_id : int ) -> TagDict :
123- stmt_get = get_tag_stmt (user_id = self .user_id , tag_id = tag_id )
124- result = await conn .execute (stmt_get )
125- row = await result .first ()
126- if not row :
127- msg = f"{ tag_id = } not found: either no access or does not exists"
128- raise TagNotFoundError (msg )
129- return TagDict (row .items ()) # type: ignore
131+ async def get (
132+ self ,
133+ connection : AsyncConnection | None = None ,
134+ * ,
135+ user_id : int ,
136+ tag_id : int ,
137+ ) -> TagDict :
138+ stmt_get = get_tag_stmt (user_id = user_id , tag_id = tag_id )
139+ async with get_or_create_connection (self .engine , connection ) as conn :
140+ result = await conn .execute (stmt_get )
141+ row = result .first ()
142+ if not row :
143+ msg = f"{ tag_id = } not found: either no access or does not exists"
144+ raise TagNotFoundError (msg )
145+ return TagDict (row .items ()) # type: ignore
130146
131147 async def update (
132148 self ,
133- conn : SAConnection ,
149+ connection : AsyncConnection | None = None ,
150+ * ,
151+ user_id : int ,
134152 tag_id : int ,
135153 ** fields ,
136154 ) -> TagDict :
137- updates = {
138- name : value
139- for name , value in fields .items ()
140- if name in {"name" , "color" , "description" }
141- }
142-
143- if not updates :
144- # no updates == get
145- return await self .get (conn , tag_id = tag_id )
146-
147- update_stmt = update_tag_stmt (user_id = self .user_id , tag_id = tag_id , ** updates )
148- result = await conn .execute (update_stmt )
149- row = await result .first ()
150- if not row :
151- msg = f"{ tag_id = } not updated: either no access or not found"
152- raise TagOperationNotAllowedError (msg )
153-
154- return TagDict (row .items ()) # type: ignore
155-
156- async def delete (self , conn : SAConnection , tag_id : int ) -> None :
157- stmt_delete = delete_tag_stmt (user_id = self .user_id , tag_id = tag_id )
158-
159- deleted = await conn .scalar (stmt_delete )
160- if not deleted :
161- msg = f"Could not delete { tag_id = } . Not found or insuficient access."
162- raise TagOperationNotAllowedError (msg )
155+ async with get_or_create_connection (self .engine , connection ) as conn :
156+ updates = {
157+ name : value
158+ for name , value in fields .items ()
159+ if name in {"name" , "color" , "description" }
160+ }
161+
162+ if not updates :
163+ # no updates == get
164+ return await self .get (conn , user_id = user_id , tag_id = tag_id )
165+
166+ update_stmt = update_tag_stmt (user_id = user_id , tag_id = tag_id , ** updates )
167+ result = await conn .execute (update_stmt )
168+ row = result .first ()
169+ if not row :
170+ msg = f"{ tag_id = } not updated: either no access or not found"
171+ raise TagOperationNotAllowedError (msg )
172+
173+ return TagDict (row .items ()) # type: ignore
174+
175+ async def delete (
176+ self ,
177+ connection : AsyncConnection | None = None ,
178+ * ,
179+ user_id : int ,
180+ tag_id : int ,
181+ ) -> None :
182+ stmt_delete = delete_tag_stmt (user_id = user_id , tag_id = tag_id )
183+ async with get_or_create_connection (self .engine , connection ) as conn :
184+ deleted = await conn .scalar (stmt_delete )
185+ if not deleted :
186+ msg = f"Could not delete { tag_id = } . Not found or insuficient access."
187+ raise TagOperationNotAllowedError (msg )
0 commit comments