66"""
77
88import logging
9- from collections import deque
109from typing import Any , NamedTuple , TypedDict
1110
1211import simcore_postgres_database .errors as db_errors
2827from simcore_postgres_database .utils_groups_extra_properties import (
2928 GroupExtraPropertiesNotFoundError ,
3029)
30+ from simcore_postgres_database .utils_repos import (
31+ pass_or_acquire_connection ,
32+ transaction_context ,
33+ )
3134from simcore_postgres_database .utils_users import generate_alternative_username
3235
33- from ..db .plugin import get_database_engine
34- from ..db .plugin import get_asyncpg_engine , get_database_engine
36+ from ..db .plugin import get_asyncpg_engine
3537from ..login .storage import AsyncpgStorage , get_plugin_storage
3638from ..security .api import clean_auth_policy_cache
3739from . import _users_repository
@@ -82,23 +84,19 @@ def _parse_as_user(user_id: Any) -> UserID:
8284
8385
8486async def get_user_profile (
85- app : web .Application , user_id : UserID , product_name : ProductName
87+ app : web .Application , * , user_id : UserID , product_name : ProductName
8688) -> MyProfileGet :
8789 """
8890 :raises UserNotFoundError:
8991 :raises MissingGroupExtraPropertiesForProductError: when product is not properly configured
9092 """
91-
92- engine = get_database_engine (app )
9393 user_profile : dict [str , Any ] = {}
9494 user_primary_group = everyone_group = {}
9595 user_standard_groups = []
9696 user_id = _parse_as_user (user_id )
9797
98- async with engine .acquire () as conn :
99- row : RowProxy
100-
101- async for row in conn .execute (
98+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
99+ result = await conn .stream (
102100 sa .select (users , groups , user_to_groups .c .access_rights )
103101 .select_from (
104102 users .join (user_to_groups , users .c .id == user_to_groups .c .uid ).join (
@@ -108,7 +106,9 @@ async def get_user_profile(
108106 .where (users .c .id == user_id )
109107 .order_by (sa .asc (groups .c .name ))
110108 .set_label_style (sa .LABEL_STYLE_TABLENAME_PLUS_COL )
111- ):
109+ )
110+
111+ async for row in result :
112112 if not user_profile :
113113 user_profile = {
114114 "id" : row .users_id ,
@@ -199,13 +199,12 @@ async def update_user_profile(
199199 user_id = _parse_as_user (user_id )
200200
201201 if updated_values := ToUserUpdateDB .from_api (update ).to_db ():
202- async with get_database_engine (app ).acquire () as conn :
202+
203+ async with transaction_context (engine = get_asyncpg_engine (app )) as conn :
203204 query = users .update ().where (users .c .id == user_id ).values (** updated_values )
204205
205206 try :
206-
207- resp = await conn .execute (query )
208- assert resp .rowcount == 1 # nosec
207+ await conn .execute (query )
209208
210209 except db_errors .UniqueViolation as err :
211210 user_name = updated_values .get ("name" )
@@ -218,15 +217,14 @@ async def update_user_profile(
218217 ) from err
219218
220219
221- async def get_user_role (app : web .Application , user_id : UserID ) -> UserRole :
220+ async def get_user_role (app : web .Application , * , user_id : UserID ) -> UserRole :
222221 """
223222 :raises UserNotFoundError:
224223 """
225224 user_id = _parse_as_user (user_id )
226225
227- engine = get_database_engine (app )
228- async with engine .acquire () as conn :
229- user_role : RowProxy | None = await conn .scalar (
226+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
227+ user_role = await conn .scalar (
230228 sa .select (users .c .role ).where (users .c .id == user_id )
231229 )
232230 if user_role is None :
@@ -289,14 +287,11 @@ async def get_user_display_and_id_names(
289287
290288
291289async def get_guest_user_ids_and_names (app : web .Application ) -> list [tuple [int , str ]]:
292- engine = get_database_engine (app )
293- result : deque = deque ()
294- async with engine .acquire () as conn :
295- async for row in conn .execute (
290+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
291+ result = await conn .stream (
296292 sa .select (users .c .id , users .c .name ).where (users .c .role == UserRole .GUEST )
297- ):
298- result .append (row .as_tuple ())
299- return list (result )
293+ )
294+ return [(row .id , row .name ) async for row in result ]
300295
301296
302297async def delete_user_without_projects (app : web .Application , user_id : UserID ) -> None :
@@ -305,6 +300,7 @@ async def delete_user_without_projects(app: web.Application, user_id: UserID) ->
305300 # otherwise this function will raise asyncpg.exceptions.ForeignKeyViolationError
306301 # Consider "marking" users as deleted and havning a background job that
307302 # cleans it up
303+ # TODO: upgrade!!!
308304 db : AsyncpgStorage = get_plugin_storage (app )
309305 user = await db .get_user ({"id" : user_id })
310306 if not user :
@@ -331,8 +327,8 @@ async def get_user_fullname(app: web.Application, user_id: UserID) -> FullNameDi
331327 """
332328 user_id = _parse_as_user (user_id )
333329
334- async with get_database_engine ( app ). acquire ( ) as conn :
335- result = await conn .execute (
330+ async with pass_or_acquire_connection ( engine = get_asyncpg_engine ( app )) as conn :
331+ result = await conn .stream (
336332 sa .select (users .c .first_name , users .c .last_name ).where (
337333 users .c .id == user_id
338334 )
@@ -354,12 +350,11 @@ async def get_user(app: web.Application, user_id: UserID) -> dict[str, Any]:
354350 row = await _users_repository .get_user_or_raise (
355351 engine = get_asyncpg_engine (app ), user_id = user_id
356352 )
357- return dict ( row )
353+ return row . _asdict ( )
358354
359355
360356async def get_user_id_from_gid (app : web .Application , primary_gid : int ) -> UserID :
361- engine = get_database_engine (app )
362- async with engine .acquire () as conn :
357+ async with pass_or_acquire_connection (engine = get_asyncpg_engine (app )) as conn :
363358 user_id : UserID = await conn .scalar (
364359 sa .select (users .c .id ).where (users .c .primary_gid == primary_gid )
365360 )
0 commit comments