11""" Repository pattern, errors and data structures for models.tags
22"""
33
4- import itertools
5- from dataclasses import dataclass
64from typing import TypedDict
75
8- from aiopg . sa . connection import SAConnection
6+ from sqlalchemy . ext . asyncio import AsyncConnection , AsyncEngine
97
8+ from .base_repo import pass_or_acquire_connection , transaction_context
109from .utils_tags_sql import (
1110 count_users_with_access_rights_stmt ,
1211 create_tag_stmt ,
@@ -49,15 +48,16 @@ class TagDict(TypedDict, total=True):
4948 delete : bool
5049
5150
52- @dataclass (frozen = True )
5351class TagsRepo :
54- user_id : int # Determines access-rights
52+ def __init__ (self , engine : AsyncEngine ):
53+ self .engine = engine
5554
5655 async def access_count (
5756 self ,
58- conn : SAConnection ,
59- tag_id : int ,
57+ connection : AsyncConnection | None = None ,
6058 * ,
59+ user_id : int ,
60+ tag_id : int ,
6161 read : bool | None = None ,
6262 write : bool | None = None ,
6363 delete : bool | None = None ,
@@ -66,20 +66,22 @@ async def access_count(
6666 Returns 0 if tag does not match access
6767 Returns >0 if it does and represents the number of groups granting this access to the user
6868 """
69- count_stmt = count_users_with_access_rights_stmt (
70- user_id = self .user_id , tag_id = tag_id , read = read , write = write , delete = delete
71- )
72- permissions_count : int | None = await conn .scalar (count_stmt )
73- return permissions_count if permissions_count else 0
69+ async with pass_or_acquire_connection (self .engine , connection ) as conn :
70+ count_stmt = count_users_with_access_rights_stmt (
71+ user_id = user_id , tag_id = tag_id , read = read , write = write , delete = delete
72+ )
73+ permissions_count : int | None = await conn .scalar (count_stmt )
74+ return permissions_count if permissions_count else 0
7475
7576 #
7677 # CRUD operations
7778 #
7879
7980 async def create (
8081 self ,
81- conn : SAConnection ,
82+ connection : AsyncConnection | None = None ,
8283 * ,
84+ user_id : int ,
8385 name : str ,
8486 color : str ,
8587 description : str | None = None , # =nullable
@@ -94,69 +96,127 @@ async def create(
9496 if description :
9597 values ["description" ] = description
9698
97- async with conn . begin () :
99+ async with transaction_context ( self . engine , connection ) as conn :
98100 # insert new tag
99101 insert_stmt = create_tag_stmt (** values )
100102 result = await conn .execute (insert_stmt )
101- tag = await result .first ()
103+ tag = result .first ()
102104 assert tag # nosec
103105
104106 # take tag ownership
105107 access_stmt = set_tag_access_rights_stmt (
106108 tag_id = tag .id ,
107- user_id = self . user_id ,
109+ user_id = user_id ,
108110 read = read ,
109111 write = write ,
110112 delete = delete ,
111113 )
112114 result = await conn .execute (access_stmt )
113- access = await result .first ()
114- assert access
115-
116- return TagDict (itertools .chain (tag .items (), access .items ())) # type: ignore
117-
118- async def list_all (self , conn : SAConnection ) -> list [TagDict ]:
119- stmt_list = list_tags_stmt (user_id = self .user_id )
120- return [TagDict (row .items ()) async for row in conn .execute (stmt_list )] # type: ignore
115+ access = result .first ()
116+ assert access # nosec
117+
118+ return TagDict (
119+ id = tag .id ,
120+ name = tag .name ,
121+ description = tag .description ,
122+ color = tag .color ,
123+ read = access .read ,
124+ write = access .write ,
125+ delete = access .delete ,
126+ )
121127
122- async def get (self , conn : SAConnection , tag_id : int ) -> TagDict :
123- stmt_get = get_tag_stmt (user_id = self .user_id , tag_id = tag_id )
124- result = await conn .execute (stmt_get )
125- row = await result .first ()
126- if not row :
127- msg = f"{ tag_id = } not found: either no access or does not exists"
128- raise TagNotFoundError (msg )
129- return TagDict (row .items ()) # type: ignore
128+ async def list_all (
129+ self ,
130+ connection : AsyncConnection | None = None ,
131+ * ,
132+ user_id : int ,
133+ ) -> list [TagDict ]:
134+ async with pass_or_acquire_connection (self .engine , connection ) as conn :
135+ stmt_list = list_tags_stmt (user_id = user_id )
136+ result = await conn .stream (stmt_list )
137+ return [
138+ TagDict (
139+ id = row .id ,
140+ name = row .name ,
141+ description = row .description ,
142+ color = row .color ,
143+ read = row .read ,
144+ write = row .write ,
145+ delete = row .delete ,
146+ )
147+ async for row in result
148+ ]
149+
150+ async def get (
151+ self ,
152+ connection : AsyncConnection | None = None ,
153+ * ,
154+ user_id : int ,
155+ tag_id : int ,
156+ ) -> TagDict :
157+ stmt_get = get_tag_stmt (user_id = user_id , tag_id = tag_id )
158+ async with pass_or_acquire_connection (self .engine , connection ) as conn :
159+ result = await conn .execute (stmt_get )
160+ row = result .first ()
161+ if not row :
162+ msg = f"{ tag_id = } not found: either no access or does not exists"
163+ raise TagNotFoundError (msg )
164+ return TagDict (
165+ id = row .id ,
166+ name = row .name ,
167+ description = row .description ,
168+ color = row .color ,
169+ read = row .read ,
170+ write = row .write ,
171+ delete = row .delete ,
172+ )
130173
131174 async def update (
132175 self ,
133- conn : SAConnection ,
176+ connection : AsyncConnection | None = None ,
177+ * ,
178+ user_id : int ,
134179 tag_id : int ,
135180 ** fields ,
136181 ) -> TagDict :
137- updates = {
138- name : value
139- for name , value in fields .items ()
140- if name in {"name" , "color" , "description" }
141- }
142-
143- if not updates :
144- # no updates == get
145- return await self .get (conn , tag_id = tag_id )
146-
147- update_stmt = update_tag_stmt (user_id = self .user_id , tag_id = tag_id , ** updates )
148- result = await conn .execute (update_stmt )
149- row = await result .first ()
150- if not row :
151- msg = f"{ tag_id = } not updated: either no access or not found"
152- raise TagOperationNotAllowedError (msg )
153-
154- return TagDict (row .items ()) # type: ignore
155-
156- async def delete (self , conn : SAConnection , tag_id : int ) -> None :
157- stmt_delete = delete_tag_stmt (user_id = self .user_id , tag_id = tag_id )
182+ async with transaction_context (self .engine , connection ) as conn :
183+ updates = {
184+ name : value
185+ for name , value in fields .items ()
186+ if name in {"name" , "color" , "description" }
187+ }
188+
189+ if not updates :
190+ # no updates == get
191+ return await self .get (conn , user_id = user_id , tag_id = tag_id )
192+
193+ update_stmt = update_tag_stmt (user_id = user_id , tag_id = tag_id , ** updates )
194+ result = await conn .execute (update_stmt )
195+ row = result .first ()
196+ if not row :
197+ msg = f"{ tag_id = } not updated: either no access or not found"
198+ raise TagOperationNotAllowedError (msg )
199+
200+ return TagDict (
201+ id = row .id ,
202+ name = row .name ,
203+ description = row .description ,
204+ color = row .color ,
205+ read = row .read ,
206+ write = row .write ,
207+ delete = row .delete ,
208+ )
158209
159- deleted = await conn .scalar (stmt_delete )
160- if not deleted :
161- msg = f"Could not delete { tag_id = } . Not found or insuficient access."
162- raise TagOperationNotAllowedError (msg )
210+ async def delete (
211+ self ,
212+ connection : AsyncConnection | None = None ,
213+ * ,
214+ user_id : int ,
215+ tag_id : int ,
216+ ) -> None :
217+ stmt_delete = delete_tag_stmt (user_id = user_id , tag_id = tag_id )
218+ async with transaction_context (self .engine , connection ) as conn :
219+ deleted = await conn .scalar (stmt_delete )
220+ if not deleted :
221+ msg = f"Could not delete { tag_id = } . Not found or insuficient access."
222+ raise TagOperationNotAllowedError (msg )
0 commit comments