Skip to content

Commit 1e7fc8a

Browse files
committed
merged
1 parent a053a3d commit 1e7fc8a

File tree

2 files changed

+79
-241
lines changed

2 files changed

+79
-241
lines changed

packages/postgres-database/src/simcore_postgres_database/tags_repo.py

Lines changed: 0 additions & 187 deletions
This file was deleted.

packages/postgres-database/src/simcore_postgres_database/utils_tags.py

Lines changed: 79 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
"""
33

44
import itertools
5-
from dataclasses import dataclass
65
from typing import TypedDict
76

8-
from aiopg.sa.connection import SAConnection
7+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
98

9+
from .base_repo import get_or_create_connection, transaction_context
1010
from .utils_tags_sql import (
1111
count_users_with_access_rights_stmt,
1212
create_tag_stmt,
@@ -49,15 +49,16 @@ class TagDict(TypedDict, total=True):
4949
delete: bool
5050

5151

52-
@dataclass(frozen=True)
5352
class TagsRepo:
54-
user_id: int # Determines access-rights
53+
def __init__(self, engine: AsyncEngine):
54+
self.engine = engine
5555

5656
async def access_count(
5757
self,
58-
conn: SAConnection,
59-
tag_id: int,
58+
connection: AsyncConnection | None = None,
6059
*,
60+
user_id: int,
61+
tag_id: int,
6162
read: bool | None = None,
6263
write: bool | None = None,
6364
delete: bool | None = None,
@@ -66,20 +67,22 @@ async def access_count(
6667
Returns 0 if tag does not match access
6768
Returns >0 if it does and represents the number of groups granting this access to the user
6869
"""
69-
count_stmt = count_users_with_access_rights_stmt(
70-
user_id=self.user_id, tag_id=tag_id, read=read, write=write, delete=delete
71-
)
72-
permissions_count: int | None = await conn.scalar(count_stmt)
73-
return permissions_count if permissions_count else 0
70+
async with get_or_create_connection(self.engine, connection) as conn:
71+
count_stmt = count_users_with_access_rights_stmt(
72+
user_id=user_id, tag_id=tag_id, read=read, write=write, delete=delete
73+
)
74+
permissions_count: int | None = await conn.scalar(count_stmt)
75+
return permissions_count if permissions_count else 0
7476

7577
#
7678
# CRUD operations
7779
#
7880

7981
async def create(
8082
self,
81-
conn: SAConnection,
83+
connection: AsyncConnection | None = None,
8284
*,
85+
user_id: int,
8386
name: str,
8487
color: str,
8588
description: str | None = None, # =nullable
@@ -94,69 +97,91 @@ async def create(
9497
if description:
9598
values["description"] = description
9699

97-
async with conn.begin():
100+
async with transaction_context(self.engine, connection) as conn:
98101
# insert new tag
99102
insert_stmt = create_tag_stmt(**values)
100103
result = await conn.execute(insert_stmt)
101-
tag = await result.first()
104+
tag = result.first()
102105
assert tag # nosec
103106

104107
# take tag ownership
105108
access_stmt = set_tag_access_rights_stmt(
106109
tag_id=tag.id,
107-
user_id=self.user_id,
110+
user_id=user_id,
108111
read=read,
109112
write=write,
110113
delete=delete,
111114
)
112115
result = await conn.execute(access_stmt)
113-
access = await result.first()
116+
access = result.first()
114117
assert access
115118

116119
return TagDict(itertools.chain(tag.items(), access.items())) # type: ignore
117120

118-
async def list_all(self, conn: SAConnection) -> list[TagDict]:
119-
stmt_list = list_tags_stmt(user_id=self.user_id)
120-
return [TagDict(row.items()) async for row in conn.execute(stmt_list)] # type: ignore
121+
async def list_all(
122+
self,
123+
connection: AsyncConnection | None = None,
124+
*,
125+
user_id: int,
126+
) -> list[TagDict]:
127+
async with get_or_create_connection(self.engine, connection) as conn:
128+
stmt_list = list_tags_stmt(user_id=user_id)
129+
return [TagDict(row.items()) async for row in conn.execute(stmt_list)] # type: ignore
121130

122-
async def get(self, conn: SAConnection, tag_id: int) -> TagDict:
123-
stmt_get = get_tag_stmt(user_id=self.user_id, tag_id=tag_id)
124-
result = await conn.execute(stmt_get)
125-
row = await result.first()
126-
if not row:
127-
msg = f"{tag_id=} not found: either no access or does not exists"
128-
raise TagNotFoundError(msg)
129-
return TagDict(row.items()) # type: ignore
131+
async def get(
132+
self,
133+
connection: AsyncConnection | None = None,
134+
*,
135+
user_id: int,
136+
tag_id: int,
137+
) -> TagDict:
138+
stmt_get = get_tag_stmt(user_id=user_id, tag_id=tag_id)
139+
async with get_or_create_connection(self.engine, connection) as conn:
140+
result = await conn.execute(stmt_get)
141+
row = result.first()
142+
if not row:
143+
msg = f"{tag_id=} not found: either no access or does not exists"
144+
raise TagNotFoundError(msg)
145+
return TagDict(row.items()) # type: ignore
130146

131147
async def update(
132148
self,
133-
conn: SAConnection,
149+
connection: AsyncConnection | None = None,
150+
*,
151+
user_id: int,
134152
tag_id: int,
135153
**fields,
136154
) -> TagDict:
137-
updates = {
138-
name: value
139-
for name, value in fields.items()
140-
if name in {"name", "color", "description"}
141-
}
142-
143-
if not updates:
144-
# no updates == get
145-
return await self.get(conn, tag_id=tag_id)
146-
147-
update_stmt = update_tag_stmt(user_id=self.user_id, tag_id=tag_id, **updates)
148-
result = await conn.execute(update_stmt)
149-
row = await result.first()
150-
if not row:
151-
msg = f"{tag_id=} not updated: either no access or not found"
152-
raise TagOperationNotAllowedError(msg)
153-
154-
return TagDict(row.items()) # type: ignore
155-
156-
async def delete(self, conn: SAConnection, tag_id: int) -> None:
157-
stmt_delete = delete_tag_stmt(user_id=self.user_id, tag_id=tag_id)
158-
159-
deleted = await conn.scalar(stmt_delete)
160-
if not deleted:
161-
msg = f"Could not delete {tag_id=}. Not found or insuficient access."
162-
raise TagOperationNotAllowedError(msg)
155+
async with get_or_create_connection(self.engine, connection) as conn:
156+
updates = {
157+
name: value
158+
for name, value in fields.items()
159+
if name in {"name", "color", "description"}
160+
}
161+
162+
if not updates:
163+
# no updates == get
164+
return await self.get(conn, user_id=user_id, tag_id=tag_id)
165+
166+
update_stmt = update_tag_stmt(user_id=user_id, tag_id=tag_id, **updates)
167+
result = await conn.execute(update_stmt)
168+
row = result.first()
169+
if not row:
170+
msg = f"{tag_id=} not updated: either no access or not found"
171+
raise TagOperationNotAllowedError(msg)
172+
173+
return TagDict(row.items()) # type: ignore
174+
175+
async def delete(
176+
self,
177+
connection: AsyncConnection | None = None,
178+
*,
179+
user_id: int,
180+
tag_id: int,
181+
) -> None:
182+
stmt_delete = delete_tag_stmt(user_id=user_id, tag_id=tag_id)
183+
async with get_or_create_connection(self.engine, connection) as conn:
184+
deleted = await conn.scalar(stmt_delete)
185+
if not deleted:
186+
msg = f"Could not delete {tag_id=}. Not found or insuficient access."
187+
raise TagOperationNotAllowedError(msg)

0 commit comments

Comments
 (0)