Skip to content

Commit 4eab7cd

Browse files
committed
tags repo uses new helpers
1 parent 34ac81b commit 4eab7cd

File tree

2 files changed

+242
-109
lines changed

2 files changed

+242
-109
lines changed
Lines changed: 118 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
""" Repository pattern, errors and data structures for models.tags
22
"""
33

4-
import itertools
5-
from dataclasses import dataclass
64
from typing import TypedDict
75

8-
from aiopg.sa.connection import SAConnection
6+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
97

8+
from .base_repo import pass_or_acquire_connection, transaction_context
109
from .utils_tags_sql import (
1110
count_users_with_access_rights_stmt,
1211
create_tag_stmt,
@@ -49,15 +48,16 @@ class TagDict(TypedDict, total=True):
4948
delete: bool
5049

5150

52-
@dataclass(frozen=True)
5351
class TagsRepo:
54-
user_id: int # Determines access-rights
52+
def __init__(self, engine: AsyncEngine):
53+
self.engine = engine
5554

5655
async def access_count(
5756
self,
58-
conn: SAConnection,
59-
tag_id: int,
57+
connection: AsyncConnection | None = None,
6058
*,
59+
user_id: int,
60+
tag_id: int,
6161
read: bool | None = None,
6262
write: bool | None = None,
6363
delete: bool | None = None,
@@ -66,20 +66,22 @@ async def access_count(
6666
Returns 0 if tag does not match access
6767
Returns >0 if it does and represents the number of groups granting this access to the user
6868
"""
69-
count_stmt = count_users_with_access_rights_stmt(
70-
user_id=self.user_id, tag_id=tag_id, read=read, write=write, delete=delete
71-
)
72-
permissions_count: int | None = await conn.scalar(count_stmt)
73-
return permissions_count if permissions_count else 0
69+
async with pass_or_acquire_connection(self.engine, connection) as conn:
70+
count_stmt = count_users_with_access_rights_stmt(
71+
user_id=user_id, tag_id=tag_id, read=read, write=write, delete=delete
72+
)
73+
permissions_count: int | None = await conn.scalar(count_stmt)
74+
return permissions_count if permissions_count else 0
7475

7576
#
7677
# CRUD operations
7778
#
7879

7980
async def create(
8081
self,
81-
conn: SAConnection,
82+
connection: AsyncConnection | None = None,
8283
*,
84+
user_id: int,
8385
name: str,
8486
color: str,
8587
description: str | None = None, # =nullable
@@ -94,69 +96,127 @@ async def create(
9496
if description:
9597
values["description"] = description
9698

97-
async with conn.begin():
99+
async with transaction_context(self.engine, connection) as conn:
98100
# insert new tag
99101
insert_stmt = create_tag_stmt(**values)
100102
result = await conn.execute(insert_stmt)
101-
tag = await result.first()
103+
tag = result.first()
102104
assert tag # nosec
103105

104106
# take tag ownership
105107
access_stmt = set_tag_access_rights_stmt(
106108
tag_id=tag.id,
107-
user_id=self.user_id,
109+
user_id=user_id,
108110
read=read,
109111
write=write,
110112
delete=delete,
111113
)
112114
result = await conn.execute(access_stmt)
113-
access = await result.first()
114-
assert access
115-
116-
return TagDict(itertools.chain(tag.items(), access.items())) # type: ignore
117-
118-
async def list_all(self, conn: SAConnection) -> list[TagDict]:
119-
stmt_list = list_tags_stmt(user_id=self.user_id)
120-
return [TagDict(row.items()) async for row in conn.execute(stmt_list)] # type: ignore
115+
access = result.first()
116+
assert access # nosec
117+
118+
return TagDict(
119+
id=tag.id,
120+
name=tag.name,
121+
description=tag.description,
122+
color=tag.color,
123+
read=access.read,
124+
write=access.write,
125+
delete=access.delete,
126+
)
121127

122-
async def get(self, conn: SAConnection, tag_id: int) -> TagDict:
123-
stmt_get = get_tag_stmt(user_id=self.user_id, tag_id=tag_id)
124-
result = await conn.execute(stmt_get)
125-
row = await result.first()
126-
if not row:
127-
msg = f"{tag_id=} not found: either no access or does not exists"
128-
raise TagNotFoundError(msg)
129-
return TagDict(row.items()) # type: ignore
128+
async def list_all(
129+
self,
130+
connection: AsyncConnection | None = None,
131+
*,
132+
user_id: int,
133+
) -> list[TagDict]:
134+
async with pass_or_acquire_connection(self.engine, connection) as conn:
135+
stmt_list = list_tags_stmt(user_id=user_id)
136+
result = await conn.stream(stmt_list)
137+
return [
138+
TagDict(
139+
id=row.id,
140+
name=row.name,
141+
description=row.description,
142+
color=row.color,
143+
read=row.read,
144+
write=row.write,
145+
delete=row.delete,
146+
)
147+
async for row in result
148+
]
149+
150+
async def get(
151+
self,
152+
connection: AsyncConnection | None = None,
153+
*,
154+
user_id: int,
155+
tag_id: int,
156+
) -> TagDict:
157+
stmt_get = get_tag_stmt(user_id=user_id, tag_id=tag_id)
158+
async with pass_or_acquire_connection(self.engine, connection) as conn:
159+
result = await conn.execute(stmt_get)
160+
row = result.first()
161+
if not row:
162+
msg = f"{tag_id=} not found: either no access or does not exists"
163+
raise TagNotFoundError(msg)
164+
return TagDict(
165+
id=row.id,
166+
name=row.name,
167+
description=row.description,
168+
color=row.color,
169+
read=row.read,
170+
write=row.write,
171+
delete=row.delete,
172+
)
130173

131174
async def update(
132175
self,
133-
conn: SAConnection,
176+
connection: AsyncConnection | None = None,
177+
*,
178+
user_id: int,
134179
tag_id: int,
135180
**fields,
136181
) -> TagDict:
137-
updates = {
138-
name: value
139-
for name, value in fields.items()
140-
if name in {"name", "color", "description"}
141-
}
142-
143-
if not updates:
144-
# no updates == get
145-
return await self.get(conn, tag_id=tag_id)
146-
147-
update_stmt = update_tag_stmt(user_id=self.user_id, tag_id=tag_id, **updates)
148-
result = await conn.execute(update_stmt)
149-
row = await result.first()
150-
if not row:
151-
msg = f"{tag_id=} not updated: either no access or not found"
152-
raise TagOperationNotAllowedError(msg)
153-
154-
return TagDict(row.items()) # type: ignore
155-
156-
async def delete(self, conn: SAConnection, tag_id: int) -> None:
157-
stmt_delete = delete_tag_stmt(user_id=self.user_id, tag_id=tag_id)
182+
async with transaction_context(self.engine, connection) as conn:
183+
updates = {
184+
name: value
185+
for name, value in fields.items()
186+
if name in {"name", "color", "description"}
187+
}
188+
189+
if not updates:
190+
# no updates == get
191+
return await self.get(conn, user_id=user_id, tag_id=tag_id)
192+
193+
update_stmt = update_tag_stmt(user_id=user_id, tag_id=tag_id, **updates)
194+
result = await conn.execute(update_stmt)
195+
row = result.first()
196+
if not row:
197+
msg = f"{tag_id=} not updated: either no access or not found"
198+
raise TagOperationNotAllowedError(msg)
199+
200+
return TagDict(
201+
id=row.id,
202+
name=row.name,
203+
description=row.description,
204+
color=row.color,
205+
read=row.read,
206+
write=row.write,
207+
delete=row.delete,
208+
)
158209

159-
deleted = await conn.scalar(stmt_delete)
160-
if not deleted:
161-
msg = f"Could not delete {tag_id=}. Not found or insuficient access."
162-
raise TagOperationNotAllowedError(msg)
210+
async def delete(
211+
self,
212+
connection: AsyncConnection | None = None,
213+
*,
214+
user_id: int,
215+
tag_id: int,
216+
) -> None:
217+
stmt_delete = delete_tag_stmt(user_id=user_id, tag_id=tag_id)
218+
async with transaction_context(self.engine, connection) as conn:
219+
deleted = await conn.scalar(stmt_delete)
220+
if not deleted:
221+
msg = f"Could not delete {tag_id=}. Not found or insuficient access."
222+
raise TagOperationNotAllowedError(msg)

0 commit comments

Comments
 (0)