22
33import sqlalchemy as sa
44from aiohttp import web
5- from aiopg .sa .connection import SAConnection
6- from aiopg .sa .engine import Engine
7- from aiopg .sa .result import ResultProxy , RowProxy
85from models_library .users import GroupID , UserBillingDetails , UserID
96from simcore_postgres_database .models .groups import groups , user_to_groups
107from simcore_postgres_database .models .products import products
1613 GroupExtraPropertiesNotFoundError ,
1714 GroupExtraPropertiesRepo ,
1815)
16+ from simcore_postgres_database .utils_repos import (
17+ pass_or_acquire_connection ,
18+ transaction_context ,
19+ )
1920from simcore_postgres_database .utils_users import UsersRepo
2021from simcore_service_webserver .users .exceptions import UserNotFoundError
22+ from sqlalchemy .engine .row import Row
23+ from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
2124
2225from ..db .models import user_to_groups
23- from ..db .plugin import get_database_engine
26+ from ..db .plugin import get_asyncpg_engine
2427from .exceptions import BillingDetailsNotFoundError
2528from .schemas import Permission
2629
2730_ALL = None
2831
2932
3033async def get_user_or_raise (
31- engine : Engine , * , user_id : UserID , return_column_names : list [str ] | None = _ALL
32- ) -> RowProxy :
34+ engine : AsyncEngine ,
35+ connection : AsyncConnection | None = None ,
36+ * ,
37+ user_id : UserID ,
38+ return_column_names : list [str ] | None = _ALL ,
39+ ) -> Row :
3340 if return_column_names == _ALL :
3441 return_column_names = list (users .columns .keys ())
3542
3643 assert return_column_names is not None # nosec
3744 assert set (return_column_names ).issubset (users .columns .keys ()) # nosec
3845
39- async with engine .acquire () as conn :
40- row : RowProxy | None = await (
41- await conn .execute (
42- sa .select (* (users .columns [name ] for name in return_column_names )).where (
43- users .c .id == user_id
44- )
46+ async with pass_or_acquire_connection (engine , connection ) as conn :
47+ result = await conn .stream (
48+ sa .select (* (users .columns [name ] for name in return_column_names )).where (
49+ users .c .id == user_id
4550 )
46- ).first ()
51+ )
52+ row = await result .first ()
4753 if row is None :
4854 raise UserNotFoundError (uid = user_id )
4955 return row
5056
5157
52- async def get_users_ids_in_group (conn : SAConnection , gid : GroupID ) -> set [UserID ]:
53- result : set [UserID ] = set ()
54- query_result = await conn .execute (
55- sa .select (user_to_groups .c .uid ).where (user_to_groups .c .gid == gid )
56- )
57- async for entry in query_result :
58- result .add (entry [0 ])
59- return result
58+ async def get_users_ids_in_group (
59+ engine : AsyncEngine ,
60+ connection : AsyncConnection | None = None ,
61+ * ,
62+ group_id : GroupID ,
63+ ) -> set [UserID ]:
64+ async with pass_or_acquire_connection (engine , connection ) as conn :
65+ result = await conn .stream (
66+ sa .select (user_to_groups .c .uid ).where (user_to_groups .c .gid == group_id )
67+ )
68+ return {row .uid async for row in result }
6069
6170
6271async def list_user_permissions (
63- app : web .Application , * , user_id : UserID , product_name : str
72+ app : web .Application ,
73+ connection : AsyncConnection | None = None ,
74+ * ,
75+ user_id : UserID ,
76+ product_name : str ,
6477) -> list [Permission ]:
6578 override_services_specifications = Permission (
6679 name = "override_services_specifications" ,
6780 allowed = False ,
6881 )
6982 with contextlib .suppress (GroupExtraPropertiesNotFoundError ):
70- async with get_database_engine (app ).acquire () as conn :
83+ async with pass_or_acquire_connection (
84+ get_asyncpg_engine (app ), connection
85+ ) as conn :
7186 user_group_extra_properties = (
87+ # TODO: adapt to asyncpg
7288 await GroupExtraPropertiesRepo .get_aggregated_properties_for_user (
7389 conn , user_id = user_id , product_name = product_name
7490 )
@@ -80,34 +96,43 @@ async def list_user_permissions(
8096 return [override_services_specifications ]
8197
8298
83- async def do_update_expired_users (conn : SAConnection ) -> list [UserID ]:
84- result : ResultProxy = await conn .execute (
85- users .update ()
86- .values (status = UserStatus .EXPIRED )
87- .where (
88- (users .c .expires_at .is_not (None ))
89- & (users .c .status == UserStatus .ACTIVE )
90- & (users .c .expires_at < sa .sql .func .now ())
99+ async def do_update_expired_users (
100+ engine : AsyncEngine ,
101+ connection : AsyncConnection | None = None ,
102+ ) -> list [UserID ]:
103+ async with transaction_context (engine , connection ) as conn :
104+ result = await conn .stream (
105+ users .update ()
106+ .values (status = UserStatus .EXPIRED )
107+ .where (
108+ (users .c .expires_at .is_not (None ))
109+ & (users .c .status == UserStatus .ACTIVE )
110+ & (users .c .expires_at < sa .sql .func .now ())
111+ )
112+ .returning (users .c .id )
91113 )
92- .returning (users .c .id )
93- )
94- if rows := await result .fetchall ():
95- return [r .id for r in rows ]
96- return []
114+ return [row .id async for row in result ]
97115
98116
99117async def update_user_status (
100- engine : Engine , * , user_id : UserID , new_status : UserStatus
118+ engine : AsyncEngine ,
119+ connection : AsyncConnection | None = None ,
120+ * ,
121+ user_id : UserID ,
122+ new_status : UserStatus ,
101123):
102- async with engine . acquire ( ) as conn :
124+ async with transaction_context ( engine , connection ) as conn :
103125 await conn .execute (
104126 users .update ().values (status = new_status ).where (users .c .id == user_id )
105127 )
106128
107129
108130async def search_users_and_get_profile (
109- engine : Engine , * , email_like : str
110- ) -> list [RowProxy ]:
131+ engine : AsyncEngine ,
132+ connection : AsyncConnection | None = None ,
133+ * ,
134+ email_like : str ,
135+ ) -> list [Row ]:
111136
112137 users_alias = sa .alias (users , name = "users_alias" )
113138
@@ -117,7 +142,7 @@ async def search_users_and_get_profile(
117142 .label ("invited_by" )
118143 )
119144
120- async with engine . acquire ( ) as conn :
145+ async with pass_or_acquire_connection ( engine , connection ) as conn :
121146 columns = (
122147 users .c .first_name ,
123148 users .c .last_name ,
@@ -159,12 +184,17 @@ async def search_users_and_get_profile(
159184 .where (users .c .email .like (email_like ))
160185 )
161186
162- result = await conn .execute (sa .union (left_outer_join , right_outer_join ))
163- return await result . fetchall () or [ ]
187+ result = await conn .stream (sa .union (left_outer_join , right_outer_join ))
188+ return [ row async for row in result ]
164189
165190
166- async def get_user_products (engine : Engine , user_id : UserID ) -> list [RowProxy ]:
167- async with engine .acquire () as conn :
191+ async def get_user_products (
192+ engine : AsyncEngine ,
193+ connection : AsyncConnection | None = None ,
194+ * ,
195+ user_id : UserID ,
196+ ) -> list [Row ]:
197+ async with pass_or_acquire_connection (engine , connection ) as conn :
168198 product_name_subq = (
169199 sa .select (products .c .name )
170200 .where (products .c .group_id == groups .c .gid )
@@ -186,14 +216,19 @@ async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
186216 .where (users .c .id == user_id )
187217 .order_by (groups .c .gid )
188218 )
189- result = await conn .execute (query )
190- return await result . fetchall () or [ ]
219+ result = await conn .stream (query )
220+ return [ row async for row in result ]
191221
192222
193223async def new_user_details (
194- engine : Engine , email : str , created_by : UserID , ** other_values
224+ engine : AsyncEngine ,
225+ connection : AsyncConnection | None = None ,
226+ * ,
227+ email : str ,
228+ created_by : UserID ,
229+ ** other_values ,
195230) -> None :
196- async with engine . acquire ( ) as conn :
231+ async with transaction_context ( engine , connection ) as conn :
197232 await conn .execute (
198233 sa .insert (users_pre_registration_details ).values (
199234 created_by = created_by , pre_email = email , ** other_values
@@ -202,13 +237,13 @@ async def new_user_details(
202237
203238
204239async def get_user_billing_details (
205- engine : Engine , user_id : UserID
240+ engine : AsyncEngine , connection : AsyncConnection | None = None , * , user_id : UserID
206241) -> UserBillingDetails :
207242 """
208243 Raises:
209244 BillingDetailsNotFoundError
210245 """
211- async with engine . acquire ( ) as conn :
246+ async with pass_or_acquire_connection ( engine , connection ) as conn :
212247 user_billing_details = await UsersRepo .get_billing_details (conn , user_id )
213248 if not user_billing_details :
214249 raise BillingDetailsNotFoundError (user_id = user_id )
0 commit comments