Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def update_folder(

folder_db = await folders_db.update(
app,
folder_id=folder_id,
folders_id_or_ids=folder_id,
name=name,
parent_folder_id=parent_folder_id,
product_name=product_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
from simcore_postgres_database.models.folders_v2 import folders_v2
from simcore_postgres_database.models.projects import projects
from simcore_postgres_database.models.projects_to_folders import projects_to_folders
from simcore_postgres_database.utils_repos import (
pass_or_acquire_connection,
transaction_context,
)
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.orm import aliased
from sqlalchemy.sql import asc, desc, select

from ..db.plugin import get_database_engine
from ..db.plugin import get_asyncpg_engine
from .errors import FolderAccessForbiddenError, FolderNotFoundError

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,6 +60,7 @@ def as_dict_exclude_unset(**params) -> dict[str, Any]:

async def create(
app: web.Application,
connection: AsyncConnection | None = None,
*,
created_by_gid: GroupID,
folder_name: str,
Expand All @@ -67,8 +73,8 @@ async def create(
user_id is not None and workspace_id is not None
), "Both user_id and workspace_id cannot be provided at the same time. Please provide only one."

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(
folders_v2.insert()
.values(
name=folder_name,
Expand All @@ -88,6 +94,7 @@ async def create(

async def list_(
app: web.Application,
connection: AsyncConnection | None = None,
*,
content_of_folder_id: FolderID | None,
user_id: UserID | None,
Expand Down Expand Up @@ -142,18 +149,17 @@ async def list_(
list_query = base_query.order_by(desc(getattr(folders_v2.c, order_by.field)))
list_query = list_query.offset(offset).limit(limit)

async with get_database_engine(app).acquire() as conn:
count_result = await conn.execute(count_query)
total_count = await count_result.scalar()
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
total_count = await conn.scalar(count_query)

result = await conn.execute(list_query)
rows = await result.fetchall() or []
results: list[FolderDB] = [FolderDB.from_orm(row) for row in rows]
return cast(int, total_count), results
result = await conn.stream(list_query)
folders: list[FolderDB] = [FolderDB.from_orm(row) async for row in result]
return cast(int, total_count), folders


async def get(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
Expand All @@ -167,8 +173,8 @@ async def get(
)
)

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(query)
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(query)
row = await result.first()
if row is None:
raise FolderAccessForbiddenError(
Expand All @@ -179,6 +185,7 @@ async def get(

async def get_for_user_or_workspace(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
Expand All @@ -203,8 +210,8 @@ async def get_for_user_or_workspace(
else:
query = query.where(folders_v2.c.workspace_id == workspace_id)

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(query)
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(query)
row = await result.first()
if row is None:
raise FolderAccessForbiddenError(
Expand All @@ -213,8 +220,10 @@ async def get_for_user_or_workspace(
return FolderDB.from_orm(row)


async def _update_impl(
async def update(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folders_id_or_ids: FolderID | set[FolderID],
product_name: ProductName,
# updatable columns
Expand Down Expand Up @@ -247,64 +256,22 @@ async def _update_impl(
# single-update
query = query.where(folders_v2.c.folder_id == folders_id_or_ids)

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(query)
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(query)
row = await result.first()
if row is None:
raise FolderNotFoundError(reason=f"Folder {folders_id_or_ids} not found.")
return FolderDB.from_orm(row)


async def update_batch(
app: web.Application,
*folder_id: FolderID,
product_name: ProductName,
# updatable columns
name: str | UnSet = _unset,
parent_folder_id: FolderID | None | UnSet = _unset,
trashed_at: datetime | None | UnSet = _unset,
trashed_explicitly: bool | UnSet = _unset,
) -> FolderDB:
return await _update_impl(
app=app,
folders_id_or_ids=set(folder_id),
product_name=product_name,
name=name,
parent_folder_id=parent_folder_id,
trashed_at=trashed_at,
trashed_explicitly=trashed_explicitly,
)


async def update(
app: web.Application,
*,
folder_id: FolderID,
product_name: ProductName,
# updatable columns
name: str | UnSet = _unset,
parent_folder_id: FolderID | None | UnSet = _unset,
trashed_at: datetime | None | UnSet = _unset,
trashed_explicitly: bool | UnSet = _unset,
) -> FolderDB:
return await _update_impl(
app=app,
folders_id_or_ids=folder_id,
product_name=product_name,
name=name,
parent_folder_id=parent_folder_id,
trashed_at=trashed_at,
trashed_explicitly=trashed_explicitly,
)


async def delete_recursively(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
) -> None:
async with get_database_engine(app).acquire() as conn, conn.begin():
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
# Step 1: Define the base case for the recursive CTE
base_query = select(
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
Expand All @@ -330,10 +297,9 @@ async def delete_recursively(

# Step 4: Execute the query to get all descendants
final_query = select(folder_hierarchy_cte)
result = await conn.execute(final_query)
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
await result.fetchall() or []
)
result = await conn.stream(final_query)
# list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
rows = [row async for row in result]

# Sort folders so that child folders come first
sorted_folders = sorted(
Expand All @@ -347,6 +313,7 @@ async def delete_recursively(

async def get_projects_recursively_only_if_user_is_owner(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
private_workspace_user_id_or_none: UserID | None,
Expand All @@ -361,7 +328,8 @@ async def get_projects_recursively_only_if_user_is_owner(
or the `users_to_groups` table for private workspace projects.
"""

async with get_database_engine(app).acquire() as conn, conn.begin():
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:

# Step 1: Define the base case for the recursive CTE
base_query = select(
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
Expand All @@ -370,6 +338,7 @@ async def get_projects_recursively_only_if_user_is_owner(
& (folders_v2.c.product_name == product_name)
)
folder_hierarchy_cte = base_query.cte(name="folder_hierarchy", recursive=True)

# Step 2: Define the recursive case
folder_alias = aliased(folders_v2)
recursive_query = select(
Expand All @@ -380,16 +349,15 @@ async def get_projects_recursively_only_if_user_is_owner(
folder_alias.c.parent_folder_id == folder_hierarchy_cte.c.folder_id,
)
)

# Step 3: Combine base and recursive cases into a CTE
folder_hierarchy_cte = folder_hierarchy_cte.union_all(recursive_query)

# Step 4: Execute the query to get all descendants
final_query = select(folder_hierarchy_cte)
result = await conn.execute(final_query)
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
await result.fetchall() or []
)

folder_ids = [item[0] for item in rows]
result = await conn.stream(final_query)
# list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
folder_ids = [item[0] async for item in result]

query = (
select(projects_to_folders.c.project_uuid)
Expand All @@ -402,19 +370,19 @@ async def get_projects_recursively_only_if_user_is_owner(
if private_workspace_user_id_or_none is not None:
query = query.where(projects.c.prj_owner == user_id)

result = await conn.execute(query)

rows = await result.fetchall() or []
return [ProjectID(row[0]) for row in rows]
result = await conn.stream(query)
return [ProjectID(row[0]) async for row in result]


async def get_folders_recursively(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
) -> list[FolderID]:
async with get_database_engine(app).acquire() as conn, conn.begin():
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:

# Step 1: Define the base case for the recursive CTE
base_query = select(
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
Expand All @@ -440,9 +408,5 @@ async def get_folders_recursively(

# Step 4: Execute the query to get all descendants
final_query = select(folder_hierarchy_cte)
result = await conn.execute(final_query)
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
await result.fetchall() or []
)

return [FolderID(row[0]) for row in rows]
result = await conn.stream(final_query)
return [FolderID(row[0]) async for row in result]
Loading
Loading