66"""
77
88import logging
9- from collections import deque
109from typing import Any , NamedTuple , TypedDict
1110
1211import simcore_postgres_database .errors as db_errors
2726from simcore_postgres_database .utils_groups_extra_properties import (
2827 GroupExtraPropertiesNotFoundError ,
2928)
29+ from simcore_postgres_database .utils_repos import (
30+ pass_or_acquire_connection ,
31+ transaction_context ,
32+ )
3033from simcore_postgres_database .utils_users import generate_alternative_username
3134
32- from ..db .plugin import get_database_engine
33- from ..db .plugin import get_asyncpg_engine , get_database_engine
35+ from ..db .plugin import get_asyncpg_engine
3436from ..login .storage import AsyncpgStorage , get_plugin_storage
3537from ..security .api import clean_auth_policy_cache
3638from . import _users_repository
@@ -81,23 +83,19 @@ def _parse_as_user(user_id: Any) -> UserID:
8183
8284
8385async def get_user_profile (
84- app : web .Application , user_id : UserID , product_name : ProductName
86+ app : web .Application , * , user_id : UserID , product_name : ProductName
8587) -> MyProfileGet :
8688 """
8789 :raises UserNotFoundError:
8890 :raises MissingGroupExtraPropertiesForProductError: when product is not properly configured
8991 """
90-
91- engine = get_database_engine (app )
9292 user_profile : dict [str , Any ] = {}
9393 user_primary_group = everyone_group = {}
9494 user_standard_groups = []
9595 user_id = _parse_as_user (user_id )
9696
97- async with engine .acquire () as conn :
98- row : RowProxy
99-
100- async for row in conn .execute (
97+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
98+ result = await conn .stream (
10199 sa .select (users , groups , user_to_groups .c .access_rights )
102100 .select_from (
103101 users .join (user_to_groups , users .c .id == user_to_groups .c .uid ).join (
@@ -107,7 +105,9 @@ async def get_user_profile(
107105 .where (users .c .id == user_id )
108106 .order_by (sa .asc (groups .c .name ))
109107 .set_label_style (sa .LABEL_STYLE_TABLENAME_PLUS_COL )
110- ):
108+ )
109+
110+ async for row in result :
111111 if not user_profile :
112112 user_profile = {
113113 "id" : row .users_id ,
@@ -198,13 +198,12 @@ async def update_user_profile(
198198 user_id = _parse_as_user (user_id )
199199
200200 if updated_values := ToUserUpdateDB .from_api (update ).to_db ():
201- async with get_database_engine (app ).acquire () as conn :
201+
202+ async with transaction_context (engine = get_asyncpg_engine (app )) as conn :
202203 query = users .update ().where (users .c .id == user_id ).values (** updated_values )
203204
204205 try :
205-
206- resp = await conn .execute (query )
207- assert resp .rowcount == 1 # nosec
206+ await conn .execute (query )
208207
209208 except db_errors .UniqueViolation as err :
210209 user_name = updated_values .get ("name" )
@@ -217,15 +216,14 @@ async def update_user_profile(
217216 ) from err
218217
219218
220- async def get_user_role (app : web .Application , user_id : UserID ) -> UserRole :
219+ async def get_user_role (app : web .Application , * , user_id : UserID ) -> UserRole :
221220 """
222221 :raises UserNotFoundError:
223222 """
224223 user_id = _parse_as_user (user_id )
225224
226- engine = get_database_engine (app )
227- async with engine .acquire () as conn :
228- user_role : RowProxy | None = await conn .scalar (
225+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
226+ user_role = await conn .scalar (
229227 sa .select (users .c .role ).where (users .c .id == user_id )
230228 )
231229 if user_role is None :
@@ -288,14 +286,11 @@ async def get_user_display_and_id_names(
288286
289287
290288async def get_guest_user_ids_and_names (app : web .Application ) -> list [tuple [int , str ]]:
291- engine = get_database_engine (app )
292- result : deque = deque ()
293- async with engine .acquire () as conn :
294- async for row in conn .execute (
289+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
290+ result = await conn .stream (
295291 sa .select (users .c .id , users .c .name ).where (users .c .role == UserRole .GUEST )
296- ):
297- result .append (row .as_tuple ())
298- return list (result )
292+ )
293+ return [(row .id , row .name ) async for row in result ]
299294
300295
301296async def delete_user_without_projects (app : web .Application , user_id : UserID ) -> None :
@@ -304,6 +299,7 @@ async def delete_user_without_projects(app: web.Application, user_id: UserID) ->
304299 # otherwise this function will raise asyncpg.exceptions.ForeignKeyViolationError
305300 # Consider "marking" users as deleted and havning a background job that
306301 # cleans it up
302+ # TODO: upgrade!!!
307303 db : AsyncpgStorage = get_plugin_storage (app )
308304 user = await db .get_user ({"id" : user_id })
309305 if not user :
@@ -330,8 +326,8 @@ async def get_user_fullname(app: web.Application, user_id: UserID) -> FullNameDi
330326 """
331327 user_id = _parse_as_user (user_id )
332328
333- async with get_database_engine ( app ). acquire ( ) as conn :
334- result = await conn .execute (
329+ async with pass_or_acquire_connection ( engine = get_asyncpg_engine ( app )) as conn :
330+ result = await conn .stream (
335331 sa .select (users .c .first_name , users .c .last_name ).where (
336332 users .c .id == user_id
337333 )
@@ -353,12 +349,11 @@ async def get_user(app: web.Application, user_id: UserID) -> dict[str, Any]:
353349 row = await _users_repository .get_user_or_raise (
354350 engine = get_asyncpg_engine (app ), user_id = user_id
355351 )
356- return dict ( row )
352+ return row . _asdict ( )
357353
358354
359355async def get_user_id_from_gid (app : web .Application , primary_gid : int ) -> UserID :
360- engine = get_database_engine (app )
361- async with engine .acquire () as conn :
356+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
362357 user_id : UserID = await conn .scalar (
363358 sa .select (users .c .id ).where (users .c .primary_gid == primary_gid )
364359 )
0 commit comments