22
33import sqlalchemy as sa
44from aiohttp import web
5- from aiopg .sa .connection import SAConnection
6- from aiopg .sa .engine import Engine
7- from aiopg .sa .result import ResultProxy , RowProxy
8- from models_library .groups import GroupID
9- from models_library .users import UserBillingDetails , UserID
5+ from models_library .users import GroupID , UserBillingDetails , UserID
106from simcore_postgres_database .models .groups import groups , user_to_groups
117from simcore_postgres_database .models .products import products
128from simcore_postgres_database .models .users import UserStatus , users
1713 GroupExtraPropertiesNotFoundError ,
1814 GroupExtraPropertiesRepo ,
1915)
16+ from simcore_postgres_database .utils_repos import (
17+ pass_or_acquire_connection ,
18+ transaction_context ,
19+ )
2020from simcore_postgres_database .utils_users import UsersRepo
2121from simcore_service_webserver .users .exceptions import UserNotFoundError
22+ from sqlalchemy .engine .row import Row
23+ from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
2224
2325from ..db .models import user_to_groups
24- from ..db .plugin import get_database_engine
26+ from ..db .plugin import get_asyncpg_engine
2527from .exceptions import BillingDetailsNotFoundError
2628from .schemas import Permission
2729
2830_ALL = None
2931
3032
3133async def get_user_or_raise (
32- engine : Engine , * , user_id : UserID , return_column_names : list [str ] | None = _ALL
33- ) -> RowProxy :
34+ engine : AsyncEngine ,
35+ connection : AsyncConnection | None = None ,
36+ * ,
37+ user_id : UserID ,
38+ return_column_names : list [str ] | None = _ALL ,
39+ ) -> Row :
3440 if return_column_names == _ALL :
3541 return_column_names = list (users .columns .keys ())
3642
3743 assert return_column_names is not None # nosec
3844 assert set (return_column_names ).issubset (users .columns .keys ()) # nosec
3945
40- async with engine .acquire () as conn :
41- row : RowProxy | None = await (
42- await conn .execute (
43- sa .select (* (users .columns [name ] for name in return_column_names )).where (
44- users .c .id == user_id
45- )
46+ async with pass_or_acquire_connection (engine , connection ) as conn :
47+ result = await conn .stream (
48+ sa .select (* (users .columns [name ] for name in return_column_names )).where (
49+ users .c .id == user_id
4650 )
47- ).first ()
51+ )
52+ row = await result .first ()
4853 if row is None :
4954 raise UserNotFoundError (uid = user_id )
5055 return row
5156
5257
53- async def get_users_ids_in_group (conn : SAConnection , gid : GroupID ) -> set [UserID ]:
54- result : set [UserID ] = set ()
55- query_result = await conn .execute (
56- sa .select (user_to_groups .c .uid ).where (user_to_groups .c .gid == gid )
57- )
58- async for entry in query_result :
59- result .add (entry [0 ])
60- return result
58+ async def get_users_ids_in_group (
59+ engine : AsyncEngine ,
60+ connection : AsyncConnection | None = None ,
61+ * ,
62+ group_id : GroupID ,
63+ ) -> set [UserID ]:
64+ async with pass_or_acquire_connection (engine , connection ) as conn :
65+ result = await conn .stream (
66+ sa .select (user_to_groups .c .uid ).where (user_to_groups .c .gid == group_id )
67+ )
68+ return {row .uid async for row in result }
6169
6270
6371async def list_user_permissions (
64- app : web .Application , * , user_id : UserID , product_name : str
72+ app : web .Application ,
73+ connection : AsyncConnection | None = None ,
74+ * ,
75+ user_id : UserID ,
76+ product_name : str ,
6577) -> list [Permission ]:
6678 override_services_specifications = Permission (
6779 name = "override_services_specifications" ,
6880 allowed = False ,
6981 )
7082 with contextlib .suppress (GroupExtraPropertiesNotFoundError ):
71- async with get_database_engine (app ).acquire () as conn :
83+ async with pass_or_acquire_connection (
84+ get_asyncpg_engine (app ), connection
85+ ) as conn :
7286 user_group_extra_properties = (
87+ # TODO: adapt to asyncpg
7388 await GroupExtraPropertiesRepo .get_aggregated_properties_for_user (
7489 conn , user_id = user_id , product_name = product_name
7590 )
@@ -81,34 +96,43 @@ async def list_user_permissions(
8196 return [override_services_specifications ]
8297
8398
84- async def do_update_expired_users (conn : SAConnection ) -> list [UserID ]:
85- result : ResultProxy = await conn .execute (
86- users .update ()
87- .values (status = UserStatus .EXPIRED )
88- .where (
89- (users .c .expires_at .is_not (None ))
90- & (users .c .status == UserStatus .ACTIVE )
91- & (users .c .expires_at < sa .sql .func .now ())
99+ async def do_update_expired_users (
100+ engine : AsyncEngine ,
101+ connection : AsyncConnection | None = None ,
102+ ) -> list [UserID ]:
103+ async with transaction_context (engine , connection ) as conn :
104+ result = await conn .stream (
105+ users .update ()
106+ .values (status = UserStatus .EXPIRED )
107+ .where (
108+ (users .c .expires_at .is_not (None ))
109+ & (users .c .status == UserStatus .ACTIVE )
110+ & (users .c .expires_at < sa .sql .func .now ())
111+ )
112+ .returning (users .c .id )
92113 )
93- .returning (users .c .id )
94- )
95- if rows := await result .fetchall ():
96- return [r .id for r in rows ]
97- return []
114+ return [row .id async for row in result ]
98115
99116
100117async def update_user_status (
101- engine : Engine , * , user_id : UserID , new_status : UserStatus
118+ engine : AsyncEngine ,
119+ connection : AsyncConnection | None = None ,
120+ * ,
121+ user_id : UserID ,
122+ new_status : UserStatus ,
102123):
103- async with engine . acquire ( ) as conn :
124+ async with transaction_context ( engine , connection ) as conn :
104125 await conn .execute (
105126 users .update ().values (status = new_status ).where (users .c .id == user_id )
106127 )
107128
108129
109130async def search_users_and_get_profile (
110- engine : Engine , * , email_like : str
111- ) -> list [RowProxy ]:
131+ engine : AsyncEngine ,
132+ connection : AsyncConnection | None = None ,
133+ * ,
134+ email_like : str ,
135+ ) -> list [Row ]:
112136
113137 users_alias = sa .alias (users , name = "users_alias" )
114138
@@ -118,7 +142,7 @@ async def search_users_and_get_profile(
118142 .label ("invited_by" )
119143 )
120144
121- async with engine . acquire ( ) as conn :
145+ async with pass_or_acquire_connection ( engine , connection ) as conn :
122146 columns = (
123147 users .c .first_name ,
124148 users .c .last_name ,
@@ -160,12 +184,17 @@ async def search_users_and_get_profile(
160184 .where (users .c .email .like (email_like ))
161185 )
162186
163- result = await conn .execute (sa .union (left_outer_join , right_outer_join ))
164- return await result . fetchall () or [ ]
187+ result = await conn .stream (sa .union (left_outer_join , right_outer_join ))
188+ return [ row async for row in result ]
165189
166190
167- async def get_user_products (engine : Engine , user_id : UserID ) -> list [RowProxy ]:
168- async with engine .acquire () as conn :
191+ async def get_user_products (
192+ engine : AsyncEngine ,
193+ connection : AsyncConnection | None = None ,
194+ * ,
195+ user_id : UserID ,
196+ ) -> list [Row ]:
197+ async with pass_or_acquire_connection (engine , connection ) as conn :
169198 product_name_subq = (
170199 sa .select (products .c .name )
171200 .where (products .c .group_id == groups .c .gid )
@@ -187,14 +216,19 @@ async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
187216 .where (users .c .id == user_id )
188217 .order_by (groups .c .gid )
189218 )
190- result = await conn .execute (query )
191- return await result . fetchall () or [ ]
219+ result = await conn .stream (query )
220+ return [ row async for row in result ]
192221
193222
194223async def new_user_details (
195- engine : Engine , email : str , created_by : UserID , ** other_values
224+ engine : AsyncEngine ,
225+ connection : AsyncConnection | None = None ,
226+ * ,
227+ email : str ,
228+ created_by : UserID ,
229+ ** other_values ,
196230) -> None :
197- async with engine . acquire ( ) as conn :
231+ async with transaction_context ( engine , connection ) as conn :
198232 await conn .execute (
199233 sa .insert (users_pre_registration_details ).values (
200234 created_by = created_by , pre_email = email , ** other_values
@@ -203,13 +237,13 @@ async def new_user_details(
203237
204238
205239async def get_user_billing_details (
206- engine : Engine , user_id : UserID
240+ engine : AsyncEngine , connection : AsyncConnection | None = None , * , user_id : UserID
207241) -> UserBillingDetails :
208242 """
209243 Raises:
210244 BillingDetailsNotFoundError
211245 """
212- async with engine . acquire ( ) as conn :
246+ async with pass_or_acquire_connection ( engine , connection ) as conn :
213247 user_billing_details = await UsersRepo .get_billing_details (conn , user_id )
214248 if not user_billing_details :
215249 raise BillingDetailsNotFoundError (user_id = user_id )
0 commit comments