1313from simcore_postgres_database .models .workspaces_access_rights import (
1414 workspaces_access_rights ,
1515)
16+ from simcore_postgres_database .utils_repos import (
17+ pass_or_acquire_connection ,
18+ transaction_context ,
19+ )
1620from sqlalchemy import func , literal_column
21+ from sqlalchemy .ext .asyncio import AsyncConnection
1722from sqlalchemy .sql import select
1823
19- from ..db .plugin import get_database_engine
24+ from ..db .plugin import get_asyncpg_engine
2025from .errors import WorkspaceGroupNotFoundError
2126
2227_logger = logging .getLogger (__name__ )
@@ -41,14 +46,15 @@ class Config:
4146
4247async def create_workspace_group (
4348 app : web .Application ,
49+ connection : AsyncConnection | None = None ,
50+ * ,
4451 workspace_id : WorkspaceID ,
4552 group_id : GroupID ,
46- * ,
4753 read : bool ,
4854 write : bool ,
4955 delete : bool ,
5056) -> WorkspaceGroupGetDB :
51- async with get_database_engine ( app ). acquire ( ) as conn :
57+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
5258 result = await conn .execute (
5359 workspaces_access_rights .insert ()
5460 .values (
@@ -62,12 +68,14 @@ async def create_workspace_group(
6268 )
6369 .returning (literal_column ("*" ))
6470 )
65- row = await result .first ()
71+ row = result .first ()
6672 return WorkspaceGroupGetDB .from_orm (row )
6773
6874
6975async def list_workspace_groups (
7076 app : web .Application ,
77+ connection : AsyncConnection | None = None ,
78+ * ,
7179 workspace_id : WorkspaceID ,
7280) -> list [WorkspaceGroupGetDB ]:
7381 stmt = (
@@ -83,14 +91,16 @@ async def list_workspace_groups(
8391 .where (workspaces_access_rights .c .workspace_id == workspace_id )
8492 )
8593
86- async with get_database_engine ( app ). acquire ( ) as conn :
94+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
8795 result = await conn .execute (stmt )
88- rows = await result .fetchall () or []
96+ rows = result .fetchall () or []
8997 return [WorkspaceGroupGetDB .from_orm (row ) for row in rows ]
9098
9199
92100async def get_workspace_group (
93101 app : web .Application ,
102+ connection : AsyncConnection | None = None ,
103+ * ,
94104 workspace_id : WorkspaceID ,
95105 group_id : GroupID ,
96106) -> WorkspaceGroupGetDB :
@@ -110,9 +120,9 @@ async def get_workspace_group(
110120 )
111121 )
112122
113- async with get_database_engine ( app ). acquire ( ) as conn :
123+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
114124 result = await conn .execute (stmt )
115- row = await result .first ()
125+ row = result .first ()
116126 if row is None :
117127 raise WorkspaceGroupNotFoundError (
118128 workspace_id = workspace_id , group_id = group_id
@@ -122,14 +132,15 @@ async def get_workspace_group(
122132
123133async def update_workspace_group (
124134 app : web .Application ,
135+ connection : AsyncConnection | None = None ,
136+ * ,
125137 workspace_id : WorkspaceID ,
126138 group_id : GroupID ,
127- * ,
128139 read : bool ,
129140 write : bool ,
130141 delete : bool ,
131142) -> WorkspaceGroupGetDB :
132- async with get_database_engine ( app ). acquire ( ) as conn :
143+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
133144 result = await conn .execute (
134145 workspaces_access_rights .update ()
135146 .values (
@@ -143,7 +154,7 @@ async def update_workspace_group(
143154 )
144155 .returning (literal_column ("*" ))
145156 )
146- row = await result .first ()
157+ row = result .first ()
147158 if row is None :
148159 raise WorkspaceGroupNotFoundError (
149160 workspace_id = workspace_id , group_id = group_id
@@ -153,10 +164,12 @@ async def update_workspace_group(
153164
154165async def delete_workspace_group (
155166 app : web .Application ,
167+ connection : AsyncConnection | None = None ,
168+ * ,
156169 workspace_id : WorkspaceID ,
157170 group_id : GroupID ,
158171) -> None :
159- async with get_database_engine ( app ). acquire ( ) as conn :
172+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
160173 await conn .execute (
161174 workspaces_access_rights .delete ().where (
162175 (workspaces_access_rights .c .workspace_id == workspace_id )
0 commit comments