1919from simcore_postgres_database .models .folders_v2 import folders_v2
2020from simcore_postgres_database .models .projects import projects
2121from simcore_postgres_database .models .projects_to_folders import projects_to_folders
22+ from simcore_postgres_database .utils_repos import (
23+ pass_or_acquire_connection ,
24+ transaction_context ,
25+ )
2226from sqlalchemy import func
27+ from sqlalchemy .ext .asyncio import AsyncConnection
2328from sqlalchemy .orm import aliased
2429from sqlalchemy .sql import asc , desc , select
2530
26- from ..db .plugin import get_database_engine
31+ from ..db .plugin import get_asyncpg_engine
2732from .errors import FolderAccessForbiddenError , FolderNotFoundError
2833
2934_logger = logging .getLogger (__name__ )
@@ -55,6 +60,7 @@ def as_dict_exclude_unset(**params) -> dict[str, Any]:
5560
5661async def create (
5762 app : web .Application ,
63+ connection : AsyncConnection | None = None ,
5864 * ,
5965 created_by_gid : GroupID ,
6066 folder_name : str ,
@@ -67,8 +73,8 @@ async def create(
6773 user_id is not None and workspace_id is not None
6874 ), "Both user_id and workspace_id cannot be provided at the same time. Please provide only one."
6975
70- async with get_database_engine ( app ). acquire ( ) as conn :
71- result = await conn .execute (
76+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
77+ result = await conn .stream (
7278 folders_v2 .insert ()
7379 .values (
7480 name = folder_name ,
@@ -88,6 +94,7 @@ async def create(
8894
8995async def list_ (
9096 app : web .Application ,
97+ connection : AsyncConnection | None = None ,
9198 * ,
9299 content_of_folder_id : FolderID | None ,
93100 user_id : UserID | None ,
@@ -142,18 +149,17 @@ async def list_(
142149 list_query = base_query .order_by (desc (getattr (folders_v2 .c , order_by .field )))
143150 list_query = list_query .offset (offset ).limit (limit )
144151
145- async with get_database_engine (app ).acquire () as conn :
146- count_result = await conn .execute (count_query )
147- total_count = await count_result .scalar ()
152+ async with pass_or_acquire_connection (get_asyncpg_engine (app ), connection ) as conn :
153+ total_count = await conn .scalar (count_query )
148154
149- result = await conn .execute (list_query )
150- rows = await result .fetchall () or []
151- results : list [FolderDB ] = [FolderDB .from_orm (row ) for row in rows ]
152- return cast (int , total_count ), results
155+ result = await conn .stream (list_query )
156+ folders : list [FolderDB ] = [FolderDB .from_orm (row ) async for row in result ]
157+ return cast (int , total_count ), folders
153158
154159
155160async def get (
156161 app : web .Application ,
162+ connection : AsyncConnection | None = None ,
157163 * ,
158164 folder_id : FolderID ,
159165 product_name : ProductName ,
@@ -167,8 +173,8 @@ async def get(
167173 )
168174 )
169175
170- async with get_database_engine ( app ). acquire ( ) as conn :
171- result = await conn .execute (query )
176+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
177+ result = await conn .stream (query )
172178 row = await result .first ()
173179 if row is None :
174180 raise FolderAccessForbiddenError (
@@ -179,6 +185,7 @@ async def get(
179185
180186async def get_for_user_or_workspace (
181187 app : web .Application ,
188+ connection : AsyncConnection | None = None ,
182189 * ,
183190 folder_id : FolderID ,
184191 product_name : ProductName ,
@@ -203,8 +210,8 @@ async def get_for_user_or_workspace(
203210 else :
204211 query = query .where (folders_v2 .c .workspace_id == workspace_id )
205212
206- async with get_database_engine ( app ). acquire ( ) as conn :
207- result = await conn .execute (query )
213+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
214+ result = await conn .stream (query )
208215 row = await result .first ()
209216 if row is None :
210217 raise FolderAccessForbiddenError (
@@ -213,8 +220,10 @@ async def get_for_user_or_workspace(
213220 return FolderDB .from_orm (row )
214221
215222
216- async def _update_impl (
223+ async def update (
217224 app : web .Application ,
225+ connection : AsyncConnection | None = None ,
226+ * ,
218227 folders_id_or_ids : FolderID | set [FolderID ],
219228 product_name : ProductName ,
220229 # updatable columns
@@ -247,64 +256,22 @@ async def _update_impl(
247256 # single-update
248257 query = query .where (folders_v2 .c .folder_id == folders_id_or_ids )
249258
250- async with get_database_engine ( app ). acquire ( ) as conn :
251- result = await conn .execute (query )
259+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
260+ result = await conn .stream (query )
252261 row = await result .first ()
253262 if row is None :
254263 raise FolderNotFoundError (reason = f"Folder { folders_id_or_ids } not found." )
255264 return FolderDB .from_orm (row )
256265
257266
258- async def update_batch (
259- app : web .Application ,
260- * folder_id : FolderID ,
261- product_name : ProductName ,
262- # updatable columns
263- name : str | UnSet = _unset ,
264- parent_folder_id : FolderID | None | UnSet = _unset ,
265- trashed_at : datetime | None | UnSet = _unset ,
266- trashed_explicitly : bool | UnSet = _unset ,
267- ) -> FolderDB :
268- return await _update_impl (
269- app = app ,
270- folders_id_or_ids = set (folder_id ),
271- product_name = product_name ,
272- name = name ,
273- parent_folder_id = parent_folder_id ,
274- trashed_at = trashed_at ,
275- trashed_explicitly = trashed_explicitly ,
276- )
277-
278-
279- async def update (
280- app : web .Application ,
281- * ,
282- folder_id : FolderID ,
283- product_name : ProductName ,
284- # updatable columns
285- name : str | UnSet = _unset ,
286- parent_folder_id : FolderID | None | UnSet = _unset ,
287- trashed_at : datetime | None | UnSet = _unset ,
288- trashed_explicitly : bool | UnSet = _unset ,
289- ) -> FolderDB :
290- return await _update_impl (
291- app = app ,
292- folders_id_or_ids = folder_id ,
293- product_name = product_name ,
294- name = name ,
295- parent_folder_id = parent_folder_id ,
296- trashed_at = trashed_at ,
297- trashed_explicitly = trashed_explicitly ,
298- )
299-
300-
301267async def delete_recursively (
302268 app : web .Application ,
269+ connection : AsyncConnection | None = None ,
303270 * ,
304271 folder_id : FolderID ,
305272 product_name : ProductName ,
306273) -> None :
307- async with get_database_engine ( app ). acquire ( ) as conn , conn . begin () :
274+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
308275 # Step 1: Define the base case for the recursive CTE
309276 base_query = select (
310277 folders_v2 .c .folder_id , folders_v2 .c .parent_folder_id
@@ -330,10 +297,9 @@ async def delete_recursively(
330297
331298 # Step 4: Execute the query to get all descendants
332299 final_query = select (folder_hierarchy_cte )
333- result = await conn .execute (final_query )
334- rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
335- await result .fetchall () or []
336- )
300+ result = await conn .stream (final_query )
301+ # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
302+ rows = [row async for row in result ]
337303
338304 # Sort folders so that child folders come first
339305 sorted_folders = sorted (
@@ -347,6 +313,7 @@ async def delete_recursively(
347313
348314async def get_projects_recursively_only_if_user_is_owner (
349315 app : web .Application ,
316+ connection : AsyncConnection | None = None ,
350317 * ,
351318 folder_id : FolderID ,
352319 private_workspace_user_id_or_none : UserID | None ,
@@ -361,7 +328,8 @@ async def get_projects_recursively_only_if_user_is_owner(
361328 or the `users_to_groups` table for private workspace projects.
362329 """
363330
364- async with get_database_engine (app ).acquire () as conn , conn .begin ():
331+ async with pass_or_acquire_connection (get_asyncpg_engine (app ), connection ) as conn :
332+
365333 # Step 1: Define the base case for the recursive CTE
366334 base_query = select (
367335 folders_v2 .c .folder_id , folders_v2 .c .parent_folder_id
@@ -370,6 +338,7 @@ async def get_projects_recursively_only_if_user_is_owner(
370338 & (folders_v2 .c .product_name == product_name )
371339 )
372340 folder_hierarchy_cte = base_query .cte (name = "folder_hierarchy" , recursive = True )
341+
373342 # Step 2: Define the recursive case
374343 folder_alias = aliased (folders_v2 )
375344 recursive_query = select (
@@ -380,16 +349,15 @@ async def get_projects_recursively_only_if_user_is_owner(
380349 folder_alias .c .parent_folder_id == folder_hierarchy_cte .c .folder_id ,
381350 )
382351 )
352+
383353 # Step 3: Combine base and recursive cases into a CTE
384354 folder_hierarchy_cte = folder_hierarchy_cte .union_all (recursive_query )
355+
385356 # Step 4: Execute the query to get all descendants
386357 final_query = select (folder_hierarchy_cte )
387- result = await conn .execute (final_query )
388- rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
389- await result .fetchall () or []
390- )
391-
392- folder_ids = [item [0 ] for item in rows ]
358+ result = await conn .stream (final_query )
359+ # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
360+ folder_ids = [item [0 ] async for item in result ]
393361
394362 query = (
395363 select (projects_to_folders .c .project_uuid )
@@ -402,19 +370,19 @@ async def get_projects_recursively_only_if_user_is_owner(
402370 if private_workspace_user_id_or_none is not None :
403371 query = query .where (projects .c .prj_owner == user_id )
404372
405- result = await conn .execute (query )
406-
407- rows = await result .fetchall () or []
408- return [ProjectID (row [0 ]) for row in rows ]
373+ result = await conn .stream (query )
374+ return [ProjectID (row [0 ]) async for row in result ]
409375
410376
411377async def get_folders_recursively (
412378 app : web .Application ,
379+ connection : AsyncConnection | None = None ,
413380 * ,
414381 folder_id : FolderID ,
415382 product_name : ProductName ,
416383) -> list [FolderID ]:
417- async with get_database_engine (app ).acquire () as conn , conn .begin ():
384+ async with pass_or_acquire_connection (get_asyncpg_engine (app ), connection ) as conn :
385+
418386 # Step 1: Define the base case for the recursive CTE
419387 base_query = select (
420388 folders_v2 .c .folder_id , folders_v2 .c .parent_folder_id
@@ -440,9 +408,5 @@ async def get_folders_recursively(
440408
441409 # Step 4: Execute the query to get all descendants
442410 final_query = select (folder_hierarchy_cte )
443- result = await conn .execute (final_query )
444- rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
445- await result .fetchall () or []
446- )
447-
448- return [FolderID (row [0 ]) for row in rows ]
411+ result = await conn .stream (final_query )
412+ return [FolderID (row [0 ]) async for row in result ]
0 commit comments