2323from simcore_postgres_database .models .workspaces_access_rights import (
2424 workspaces_access_rights ,
2525)
26- from simcore_postgres_database .utils import assemble_array_groups
26+ from simcore_postgres_database .utils_repos import (
27+ pass_or_acquire_connection ,
28+ transaction_context ,
29+ )
30+ from simcore_postgres_database .utils_sql import assemble_array_groups
2731from sqlalchemy import func
32+ from sqlalchemy .ext .asyncio import AsyncConnection
2833from sqlalchemy .orm import aliased
2934from sqlalchemy .sql import ColumnElement , CompoundSelect , Select , asc , desc , select
3035
31- from ..db .plugin import get_database_engine
36+ from ..db .plugin import get_asyncpg_engine
3237from ..groups .api import list_all_user_groups
3338from .errors import FolderAccessForbiddenError , FolderNotFoundError
3439
@@ -61,6 +66,7 @@ def as_dict_exclude_unset(**params) -> dict[str, Any]:
6166
6267async def create (
6368 app : web .Application ,
69+ connection : AsyncConnection | None = None ,
6470 * ,
6571 created_by_gid : GroupID ,
6672 folder_name : str ,
@@ -73,8 +79,8 @@ async def create(
7379 user_id is not None and workspace_id is not None
7480 ), "Both user_id and workspace_id cannot be provided at the same time. Please provide only one."
7581
76- async with get_database_engine ( app ). acquire ( ) as conn :
77- result = await conn .execute (
82+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
83+ result = await conn .stream (
7884 folders_v2 .insert ()
7985 .values (
8086 name = folder_name ,
@@ -94,6 +100,7 @@ async def create(
94100
95101async def list_ ( # pylint: disable=too-many-arguments,too-many-branches
96102 app : web .Application ,
103+ connection : AsyncConnection | None = None ,
97104 * ,
98105 product_name : ProductName ,
99106 user_id : UserID ,
@@ -234,18 +241,17 @@ async def list_( # pylint: disable=too-many-arguments,too-many-branches
234241 )
235242 list_query = list_query .offset (offset ).limit (limit )
236243
237- async with get_database_engine (app ).acquire () as conn :
238- count_result = await conn .execute (count_query )
239- total_count = await count_result .scalar ()
244+ async with pass_or_acquire_connection (get_asyncpg_engine (app ), connection ) as conn :
245+ total_count = await conn .scalar (count_query )
240246
241- result = await conn .execute (list_query )
242- rows = await result .fetchall () or []
243- results : list [FolderDB ] = [FolderDB .from_orm (row ) for row in rows ]
244- return cast (int , total_count ), results
247+ result = await conn .stream (list_query )
248+ folders : list [FolderDB ] = [FolderDB .from_orm (row ) async for row in result ]
249+ return cast (int , total_count ), folders
245250
246251
247252async def get (
248253 app : web .Application ,
254+ connection : AsyncConnection | None = None ,
249255 * ,
250256 folder_id : FolderID ,
251257 product_name : ProductName ,
@@ -259,8 +265,8 @@ async def get(
259265 )
260266 )
261267
262- async with get_database_engine ( app ). acquire ( ) as conn :
263- result = await conn .execute (query )
268+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
269+ result = await conn .stream (query )
264270 row = await result .first ()
265271 if row is None :
266272 raise FolderAccessForbiddenError (
@@ -271,6 +277,7 @@ async def get(
271277
272278async def get_for_user_or_workspace (
273279 app : web .Application ,
280+ connection : AsyncConnection | None = None ,
274281 * ,
275282 folder_id : FolderID ,
276283 product_name : ProductName ,
@@ -295,8 +302,8 @@ async def get_for_user_or_workspace(
295302 else :
296303 query = query .where (folders_v2 .c .workspace_id == workspace_id )
297304
298- async with get_database_engine ( app ). acquire ( ) as conn :
299- result = await conn .execute (query )
305+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
306+ result = await conn .stream (query )
300307 row = await result .first ()
301308 if row is None :
302309 raise FolderAccessForbiddenError (
@@ -305,8 +312,10 @@ async def get_for_user_or_workspace(
305312 return FolderDB .from_orm (row )
306313
307314
308- async def _update_impl (
315+ async def update (
309316 app : web .Application ,
317+ connection : AsyncConnection | None = None ,
318+ * ,
310319 folders_id_or_ids : FolderID | set [FolderID ],
311320 product_name : ProductName ,
312321 # updatable columns
@@ -339,64 +348,22 @@ async def _update_impl(
339348 # single-update
340349 query = query .where (folders_v2 .c .folder_id == folders_id_or_ids )
341350
342- async with get_database_engine ( app ). acquire ( ) as conn :
343- result = await conn .execute (query )
351+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
352+ result = await conn .stream (query )
344353 row = await result .first ()
345354 if row is None :
346355 raise FolderNotFoundError (reason = f"Folder { folders_id_or_ids } not found." )
347356 return FolderDB .from_orm (row )
348357
349358
350- async def update_batch (
351- app : web .Application ,
352- * folder_id : FolderID ,
353- product_name : ProductName ,
354- # updatable columns
355- name : str | UnSet = _unset ,
356- parent_folder_id : FolderID | None | UnSet = _unset ,
357- trashed_at : datetime | None | UnSet = _unset ,
358- trashed_explicitly : bool | UnSet = _unset ,
359- ) -> FolderDB :
360- return await _update_impl (
361- app = app ,
362- folders_id_or_ids = set (folder_id ),
363- product_name = product_name ,
364- name = name ,
365- parent_folder_id = parent_folder_id ,
366- trashed_at = trashed_at ,
367- trashed_explicitly = trashed_explicitly ,
368- )
369-
370-
371- async def update (
372- app : web .Application ,
373- * ,
374- folder_id : FolderID ,
375- product_name : ProductName ,
376- # updatable columns
377- name : str | UnSet = _unset ,
378- parent_folder_id : FolderID | None | UnSet = _unset ,
379- trashed_at : datetime | None | UnSet = _unset ,
380- trashed_explicitly : bool | UnSet = _unset ,
381- ) -> FolderDB :
382- return await _update_impl (
383- app = app ,
384- folders_id_or_ids = folder_id ,
385- product_name = product_name ,
386- name = name ,
387- parent_folder_id = parent_folder_id ,
388- trashed_at = trashed_at ,
389- trashed_explicitly = trashed_explicitly ,
390- )
391-
392-
393359async def delete_recursively (
394360 app : web .Application ,
361+ connection : AsyncConnection | None = None ,
395362 * ,
396363 folder_id : FolderID ,
397364 product_name : ProductName ,
398365) -> None :
399- async with get_database_engine ( app ). acquire ( ) as conn , conn . begin () :
366+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
400367 # Step 1: Define the base case for the recursive CTE
401368 base_query = select (
402369 folders_v2 .c .folder_id , folders_v2 .c .parent_folder_id
@@ -422,10 +389,9 @@ async def delete_recursively(
422389
423390 # Step 4: Execute the query to get all descendants
424391 final_query = select (folder_hierarchy_cte )
425- result = await conn .execute (final_query )
426- rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
427- await result .fetchall () or []
428- )
392+ result = await conn .stream (final_query )
393+ # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
394+ rows = [row async for row in result ]
429395
430396 # Sort folders so that child folders come first
431397 sorted_folders = sorted (
@@ -439,6 +405,7 @@ async def delete_recursively(
439405
440406async def get_projects_recursively_only_if_user_is_owner (
441407 app : web .Application ,
408+ connection : AsyncConnection | None = None ,
442409 * ,
443410 folder_id : FolderID ,
444411 private_workspace_user_id_or_none : UserID | None ,
@@ -453,7 +420,8 @@ async def get_projects_recursively_only_if_user_is_owner(
453420 or the `users_to_groups` table for private workspace projects.
454421 """
455422
456- async with get_database_engine (app ).acquire () as conn , conn .begin ():
423+ async with pass_or_acquire_connection (get_asyncpg_engine (app ), connection ) as conn :
424+
457425 # Step 1: Define the base case for the recursive CTE
458426 base_query = select (
459427 folders_v2 .c .folder_id , folders_v2 .c .parent_folder_id
@@ -462,6 +430,7 @@ async def get_projects_recursively_only_if_user_is_owner(
462430 & (folders_v2 .c .product_name == product_name )
463431 )
464432 folder_hierarchy_cte = base_query .cte (name = "folder_hierarchy" , recursive = True )
433+
465434 # Step 2: Define the recursive case
466435 folder_alias = aliased (folders_v2 )
467436 recursive_query = select (
@@ -472,16 +441,15 @@ async def get_projects_recursively_only_if_user_is_owner(
472441 folder_alias .c .parent_folder_id == folder_hierarchy_cte .c .folder_id ,
473442 )
474443 )
444+
475445 # Step 3: Combine base and recursive cases into a CTE
476446 folder_hierarchy_cte = folder_hierarchy_cte .union_all (recursive_query )
447+
477448 # Step 4: Execute the query to get all descendants
478449 final_query = select (folder_hierarchy_cte )
479- result = await conn .execute (final_query )
480- rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
481- await result .fetchall () or []
482- )
483-
484- folder_ids = [item [0 ] for item in rows ]
450+ result = await conn .stream (final_query )
451+ # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
452+ folder_ids = [item [0 ] async for item in result ]
485453
486454 query = (
487455 select (projects_to_folders .c .project_uuid )
@@ -494,19 +462,19 @@ async def get_projects_recursively_only_if_user_is_owner(
494462 if private_workspace_user_id_or_none is not None :
495463 query = query .where (projects .c .prj_owner == user_id )
496464
497- result = await conn .execute (query )
498-
499- rows = await result .fetchall () or []
500- return [ProjectID (row [0 ]) for row in rows ]
465+ result = await conn .stream (query )
466+ return [ProjectID (row [0 ]) async for row in result ]
501467
502468
503469async def get_folders_recursively (
504470 app : web .Application ,
471+ connection : AsyncConnection | None = None ,
505472 * ,
506473 folder_id : FolderID ,
507474 product_name : ProductName ,
508475) -> list [FolderID ]:
509- async with get_database_engine (app ).acquire () as conn , conn .begin ():
476+ async with pass_or_acquire_connection (get_asyncpg_engine (app ), connection ) as conn :
477+
510478 # Step 1: Define the base case for the recursive CTE
511479 base_query = select (
512480 folders_v2 .c .folder_id , folders_v2 .c .parent_folder_id
@@ -532,9 +500,5 @@ async def get_folders_recursively(
532500
533501 # Step 4: Execute the query to get all descendants
534502 final_query = select (folder_hierarchy_cte )
535- result = await conn .execute (final_query )
536- rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
537- await result .fetchall () or []
538- )
539-
540- return [FolderID (row [0 ]) for row in rows ]
503+ result = await conn .stream (final_query )
504+ return [FolderID (row [0 ]) async for row in result ]
0 commit comments