Skip to content

Commit 1feb56a

Browse files
review @pcrespov
2 parents 28eef15 + 8f182d3 commit 1feb56a

File tree

8 files changed

+153
-161
lines changed

8 files changed

+153
-161
lines changed

packages/postgres-database/src/simcore_postgres_database/utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,3 @@ def as_postgres_sql_query_str(statement) -> str:
8181
dialect=postgresql.dialect(), # type: ignore[misc]
8282
)
8383
return f"{compiled}"
84-
85-
86-
def assemble_array_groups(user_group_ids: list[int]) -> str:
87-
return (
88-
"array[]::text[]"
89-
if len(user_group_ids) == 0
90-
else f"""array[{', '.join(f"'{group_id}'" for group_id in user_group_ids)}]"""
91-
)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def assemble_array_groups(user_group_ids: list[int]) -> str:
2+
return (
3+
"array[]::text[]"
4+
if len(user_group_ids) == 0
5+
else f"""array[{', '.join(f"'{group_id}'" for group_id in user_group_ids)}]"""
6+
)

services/storage/src/simcore_service_storage/db_access_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
workspaces_access_rights,
5252
)
5353
from simcore_postgres_database.storage_models import file_meta_data, user_to_groups
54-
from simcore_postgres_database.utils import assemble_array_groups
54+
from simcore_postgres_database.utils_sql import assemble_array_groups
5555

5656
logger = logging.getLogger(__name__)
5757

services/web/server/src/simcore_service_webserver/folders/_folders_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ async def update_folder(
272272

273273
folder_db = await folders_db.update(
274274
app,
275-
folder_id=folder_id,
275+
folders_id_or_ids=folder_id,
276276
name=name,
277277
parent_folder_id=parent_folder_id,
278278
product_name=product_name,

services/web/server/src/simcore_service_webserver/folders/_folders_db.py

Lines changed: 48 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,17 @@
2323
from simcore_postgres_database.models.workspaces_access_rights import (
2424
workspaces_access_rights,
2525
)
26-
from simcore_postgres_database.utils import assemble_array_groups
26+
from simcore_postgres_database.utils_repos import (
27+
pass_or_acquire_connection,
28+
transaction_context,
29+
)
30+
from simcore_postgres_database.utils_sql import assemble_array_groups
2731
from sqlalchemy import func
32+
from sqlalchemy.ext.asyncio import AsyncConnection
2833
from sqlalchemy.orm import aliased
2934
from sqlalchemy.sql import ColumnElement, CompoundSelect, Select, asc, desc, select
3035

31-
from ..db.plugin import get_database_engine
36+
from ..db.plugin import get_asyncpg_engine
3237
from ..groups.api import list_all_user_groups
3338
from .errors import FolderAccessForbiddenError, FolderNotFoundError
3439

@@ -61,6 +66,7 @@ def as_dict_exclude_unset(**params) -> dict[str, Any]:
6166

6267
async def create(
6368
app: web.Application,
69+
connection: AsyncConnection | None = None,
6470
*,
6571
created_by_gid: GroupID,
6672
folder_name: str,
@@ -73,8 +79,8 @@ async def create(
7379
user_id is not None and workspace_id is not None
7480
), "Both user_id and workspace_id cannot be provided at the same time. Please provide only one."
7581

76-
async with get_database_engine(app).acquire() as conn:
77-
result = await conn.execute(
82+
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
83+
result = await conn.stream(
7884
folders_v2.insert()
7985
.values(
8086
name=folder_name,
@@ -94,6 +100,7 @@ async def create(
94100

95101
async def list_( # pylint: disable=too-many-arguments,too-many-branches
96102
app: web.Application,
103+
connection: AsyncConnection | None = None,
97104
*,
98105
product_name: ProductName,
99106
user_id: UserID,
@@ -234,18 +241,17 @@ async def list_( # pylint: disable=too-many-arguments,too-many-branches
234241
)
235242
list_query = list_query.offset(offset).limit(limit)
236243

237-
async with get_database_engine(app).acquire() as conn:
238-
count_result = await conn.execute(count_query)
239-
total_count = await count_result.scalar()
244+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
245+
total_count = await conn.scalar(count_query)
240246

241-
result = await conn.execute(list_query)
242-
rows = await result.fetchall() or []
243-
results: list[FolderDB] = [FolderDB.from_orm(row) for row in rows]
244-
return cast(int, total_count), results
247+
result = await conn.stream(list_query)
248+
folders: list[FolderDB] = [FolderDB.from_orm(row) async for row in result]
249+
return cast(int, total_count), folders
245250

246251

247252
async def get(
248253
app: web.Application,
254+
connection: AsyncConnection | None = None,
249255
*,
250256
folder_id: FolderID,
251257
product_name: ProductName,
@@ -259,8 +265,8 @@ async def get(
259265
)
260266
)
261267

262-
async with get_database_engine(app).acquire() as conn:
263-
result = await conn.execute(query)
268+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
269+
result = await conn.stream(query)
264270
row = await result.first()
265271
if row is None:
266272
raise FolderAccessForbiddenError(
@@ -271,6 +277,7 @@ async def get(
271277

272278
async def get_for_user_or_workspace(
273279
app: web.Application,
280+
connection: AsyncConnection | None = None,
274281
*,
275282
folder_id: FolderID,
276283
product_name: ProductName,
@@ -295,8 +302,8 @@ async def get_for_user_or_workspace(
295302
else:
296303
query = query.where(folders_v2.c.workspace_id == workspace_id)
297304

298-
async with get_database_engine(app).acquire() as conn:
299-
result = await conn.execute(query)
305+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
306+
result = await conn.stream(query)
300307
row = await result.first()
301308
if row is None:
302309
raise FolderAccessForbiddenError(
@@ -305,8 +312,10 @@ async def get_for_user_or_workspace(
305312
return FolderDB.from_orm(row)
306313

307314

308-
async def _update_impl(
315+
async def update(
309316
app: web.Application,
317+
connection: AsyncConnection | None = None,
318+
*,
310319
folders_id_or_ids: FolderID | set[FolderID],
311320
product_name: ProductName,
312321
# updatable columns
@@ -339,64 +348,22 @@ async def _update_impl(
339348
# single-update
340349
query = query.where(folders_v2.c.folder_id == folders_id_or_ids)
341350

342-
async with get_database_engine(app).acquire() as conn:
343-
result = await conn.execute(query)
351+
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
352+
result = await conn.stream(query)
344353
row = await result.first()
345354
if row is None:
346355
raise FolderNotFoundError(reason=f"Folder {folders_id_or_ids} not found.")
347356
return FolderDB.from_orm(row)
348357

349358

350-
async def update_batch(
351-
app: web.Application,
352-
*folder_id: FolderID,
353-
product_name: ProductName,
354-
# updatable columns
355-
name: str | UnSet = _unset,
356-
parent_folder_id: FolderID | None | UnSet = _unset,
357-
trashed_at: datetime | None | UnSet = _unset,
358-
trashed_explicitly: bool | UnSet = _unset,
359-
) -> FolderDB:
360-
return await _update_impl(
361-
app=app,
362-
folders_id_or_ids=set(folder_id),
363-
product_name=product_name,
364-
name=name,
365-
parent_folder_id=parent_folder_id,
366-
trashed_at=trashed_at,
367-
trashed_explicitly=trashed_explicitly,
368-
)
369-
370-
371-
async def update(
372-
app: web.Application,
373-
*,
374-
folder_id: FolderID,
375-
product_name: ProductName,
376-
# updatable columns
377-
name: str | UnSet = _unset,
378-
parent_folder_id: FolderID | None | UnSet = _unset,
379-
trashed_at: datetime | None | UnSet = _unset,
380-
trashed_explicitly: bool | UnSet = _unset,
381-
) -> FolderDB:
382-
return await _update_impl(
383-
app=app,
384-
folders_id_or_ids=folder_id,
385-
product_name=product_name,
386-
name=name,
387-
parent_folder_id=parent_folder_id,
388-
trashed_at=trashed_at,
389-
trashed_explicitly=trashed_explicitly,
390-
)
391-
392-
393359
async def delete_recursively(
394360
app: web.Application,
361+
connection: AsyncConnection | None = None,
395362
*,
396363
folder_id: FolderID,
397364
product_name: ProductName,
398365
) -> None:
399-
async with get_database_engine(app).acquire() as conn, conn.begin():
366+
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
400367
# Step 1: Define the base case for the recursive CTE
401368
base_query = select(
402369
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
@@ -422,10 +389,9 @@ async def delete_recursively(
422389

423390
# Step 4: Execute the query to get all descendants
424391
final_query = select(folder_hierarchy_cte)
425-
result = await conn.execute(final_query)
426-
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
427-
await result.fetchall() or []
428-
)
392+
result = await conn.stream(final_query)
393+
# list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
394+
rows = [row async for row in result]
429395

430396
# Sort folders so that child folders come first
431397
sorted_folders = sorted(
@@ -439,6 +405,7 @@ async def delete_recursively(
439405

440406
async def get_projects_recursively_only_if_user_is_owner(
441407
app: web.Application,
408+
connection: AsyncConnection | None = None,
442409
*,
443410
folder_id: FolderID,
444411
private_workspace_user_id_or_none: UserID | None,
@@ -453,7 +420,8 @@ async def get_projects_recursively_only_if_user_is_owner(
453420
or the `users_to_groups` table for private workspace projects.
454421
"""
455422

456-
async with get_database_engine(app).acquire() as conn, conn.begin():
423+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
424+
457425
# Step 1: Define the base case for the recursive CTE
458426
base_query = select(
459427
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
@@ -462,6 +430,7 @@ async def get_projects_recursively_only_if_user_is_owner(
462430
& (folders_v2.c.product_name == product_name)
463431
)
464432
folder_hierarchy_cte = base_query.cte(name="folder_hierarchy", recursive=True)
433+
465434
# Step 2: Define the recursive case
466435
folder_alias = aliased(folders_v2)
467436
recursive_query = select(
@@ -472,16 +441,15 @@ async def get_projects_recursively_only_if_user_is_owner(
472441
folder_alias.c.parent_folder_id == folder_hierarchy_cte.c.folder_id,
473442
)
474443
)
444+
475445
# Step 3: Combine base and recursive cases into a CTE
476446
folder_hierarchy_cte = folder_hierarchy_cte.union_all(recursive_query)
447+
477448
# Step 4: Execute the query to get all descendants
478449
final_query = select(folder_hierarchy_cte)
479-
result = await conn.execute(final_query)
480-
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
481-
await result.fetchall() or []
482-
)
483-
484-
folder_ids = [item[0] for item in rows]
450+
result = await conn.stream(final_query)
451+
# list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
452+
folder_ids = [item[0] async for item in result]
485453

486454
query = (
487455
select(projects_to_folders.c.project_uuid)
@@ -494,19 +462,19 @@ async def get_projects_recursively_only_if_user_is_owner(
494462
if private_workspace_user_id_or_none is not None:
495463
query = query.where(projects.c.prj_owner == user_id)
496464

497-
result = await conn.execute(query)
498-
499-
rows = await result.fetchall() or []
500-
return [ProjectID(row[0]) for row in rows]
465+
result = await conn.stream(query)
466+
return [ProjectID(row[0]) async for row in result]
501467

502468

503469
async def get_folders_recursively(
504470
app: web.Application,
471+
connection: AsyncConnection | None = None,
505472
*,
506473
folder_id: FolderID,
507474
product_name: ProductName,
508475
) -> list[FolderID]:
509-
async with get_database_engine(app).acquire() as conn, conn.begin():
476+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
477+
510478
# Step 1: Define the base case for the recursive CTE
511479
base_query = select(
512480
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
@@ -532,9 +500,5 @@ async def get_folders_recursively(
532500

533501
# Step 4: Execute the query to get all descendants
534502
final_query = select(folder_hierarchy_cte)
535-
result = await conn.execute(final_query)
536-
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
537-
await result.fetchall() or []
538-
)
539-
540-
return [FolderID(row[0]) for row in rows]
503+
result = await conn.stream(final_query)
504+
return [FolderID(row[0]) async for row in result]

0 commit comments

Comments
 (0)