Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/specs/web-server/_nih_sparc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi import APIRouter
from models_library.generics import Envelope
from simcore_service_webserver._meta import API_VTAG
from simcore_service_webserver.studies_dispatcher._rest_handlers import (
from simcore_service_webserver.studies_dispatcher._controller.rest.nih_schemas import (
ServiceGet,
Viewer,
)
Expand Down
7 changes: 3 additions & 4 deletions api/specs/web-server/_nih_sparc_redirections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" Helper script to generate OAS automatically NIH-sparc portal API section
"""
"""Helper script to generate OAS automatically NIH-sparc portal API section"""

# pylint: disable=protected-access
# pylint: disable=redefined-outer-name
Expand All @@ -11,7 +10,7 @@
from fastapi import APIRouter, status
from fastapi.responses import RedirectResponse
from models_library.projects import ProjectID
from models_library.services import ServiceKey, ServiceKeyVersion
from models_library.services_types import ServiceKey, ServiceVersion
from pydantic import HttpUrl, PositiveInt

router = APIRouter(
Expand All @@ -31,7 +30,7 @@
async def get_redirection_to_viewer(
file_type: str,
viewer_key: ServiceKey,
viewer_version: ServiceKeyVersion,
viewer_version: ServiceVersion,
file_size: PositiveInt,
download_link: HttpUrl,
file_name: str | None = "unknown",
Expand Down
2 changes: 1 addition & 1 deletion packages/pytest-simcore/src/pytest_simcore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def keep_docker_up(request: pytest.FixtureRequest) -> bool:
return flag


@pytest.fixture
@pytest.fixture(scope="session")
def is_pdb_enabled(request: pytest.FixtureRequest):
"""Returns true if tests are set to use interactive debugger, i.e. --pdb"""
options = request.config.option
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,3 +571,32 @@ def random_itis_vip_available_download_item(

data.update(**overrides)
return data


def random_service_consume_filetype(
*,
service_key: str,
service_version: str,
fake: Faker = DEFAULT_FAKER,
**overrides,
) -> dict[str, Any]:
from simcore_postgres_database.models.services_consume_filetypes import (
services_consume_filetypes,
)

data = {
"service_key": service_key,
"service_version": service_version,
"service_display_name": fake.company(),
"service_input_port": fake.word(),
"filetype": fake.random_element(["CSV", "VTK", "H5", "JSON", "TXT"]),
"preference_order": fake.pyint(min_value=0, max_value=10),
"is_guest_allowed": fake.pybool(),
}

assert set(data.keys()).issubset( # nosec
{c.name for c in services_consume_filetypes.columns}
)

data.update(overrides)
return data
254 changes: 224 additions & 30 deletions packages/pytest-simcore/src/pytest_simcore/helpers/postgres_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,41 +88,107 @@ async def _async_insert_and_get_row(
conn: AsyncConnection,
table: sa.Table,
values: dict[str, Any],
pk_col: sa.Column,
pk_col: sa.Column | None = None,
pk_value: Any | None = None,
pk_cols: list[sa.Column] | None = None,
pk_values: list[Any] | None = None,
) -> sa.engine.Row:
result = await conn.execute(table.insert().values(**values).returning(pk_col))
# Validate parameters
single_pk_provided = pk_col is not None
composite_pk_provided = pk_cols is not None

if single_pk_provided == composite_pk_provided:
msg = "Must provide either pk_col or pk_cols, but not both"
raise ValueError(msg)

if composite_pk_provided:
if pk_values is not None and len(pk_cols) != len(pk_values):
msg = "pk_cols and pk_values must have the same length"
raise ValueError(msg)
returning_cols = pk_cols
else:
returning_cols = [pk_col]

result = await conn.execute(
table.insert().values(**values).returning(*returning_cols)
)
row = result.one()

# Get the pk_value from the row if not provided
if pk_value is None:
pk_value = getattr(row, pk_col.name)
if composite_pk_provided:
# Handle composite primary keys
if pk_values is None:
pk_values = [getattr(row, col.name) for col in pk_cols]
else:
for col, expected_value in zip(pk_cols, pk_values, strict=True):
assert getattr(row, col.name) == expected_value

# Build WHERE clause for composite key
where_clause = sa.and_(
*[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
)
else:
# NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
assert getattr(row, pk_col.name) == pk_value
# Handle single primary key (existing logic)
if pk_value is None:
pk_value = getattr(row, pk_col.name)
else:
assert getattr(row, pk_col.name) == pk_value

where_clause = pk_col == pk_value

result = await conn.execute(sa.select(table).where(pk_col == pk_value))
result = await conn.execute(sa.select(table).where(where_clause))
return result.one()


def _sync_insert_and_get_row(
conn: sa.engine.Connection,
table: sa.Table,
values: dict[str, Any],
pk_col: sa.Column,
pk_col: sa.Column | None = None,
pk_value: Any | None = None,
pk_cols: list[sa.Column] | None = None,
pk_values: list[Any] | None = None,
) -> sa.engine.Row:
result = conn.execute(table.insert().values(**values).returning(pk_col))
# Validate parameters
single_pk_provided = pk_col is not None
composite_pk_provided = pk_cols is not None

if single_pk_provided == composite_pk_provided:
msg = "Must provide either pk_col or pk_cols, but not both"
raise ValueError(msg)

if composite_pk_provided:
if pk_values is not None and len(pk_cols) != len(pk_values):
msg = "pk_cols and pk_values must have the same length"
raise ValueError(msg)
returning_cols = pk_cols
else:
returning_cols = [pk_col]

result = conn.execute(table.insert().values(**values).returning(*returning_cols))
row = result.one()

# Get the pk_value from the row if not provided
if pk_value is None:
pk_value = getattr(row, pk_col.name)
if composite_pk_provided:
# Handle composite primary keys
if pk_values is None:
pk_values = [getattr(row, col.name) for col in pk_cols]
else:
for col, expected_value in zip(pk_cols, pk_values, strict=True):
assert getattr(row, col.name) == expected_value

# Build WHERE clause for composite key
where_clause = sa.and_(
*[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
)
else:
# NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
assert getattr(row, pk_col.name) == pk_value
# Handle single primary key (existing logic)
if pk_value is None:
pk_value = getattr(row, pk_col.name)
else:
assert getattr(row, pk_col.name) == pk_value

where_clause = pk_col == pk_value

result = conn.execute(sa.select(table).where(pk_col == pk_value))
result = conn.execute(sa.select(table).where(where_clause))
return result.one()


Expand All @@ -132,27 +198,135 @@ async def insert_and_get_row_lifespan(
*,
table: sa.Table,
values: dict[str, Any],
pk_col: sa.Column,
pk_col: sa.Column | None = None,
pk_value: Any | None = None,
pk_cols: list[sa.Column] | None = None,
pk_values: list[Any] | None = None,
) -> AsyncIterator[dict[str, Any]]:
"""
Context manager that inserts a row into a table and automatically deletes it on exit.

Args:
sqlalchemy_async_engine: Async SQLAlchemy engine
table: The table to insert into
values: Dictionary of column values to insert
pk_col: Primary key column for deletion (for single-column primary keys)
pk_value: Optional primary key value (if None, will be taken from inserted row)
pk_cols: List of primary key columns (for composite primary keys)
pk_values: Optional list of primary key values (if None, will be taken from inserted row)

Yields:
dict: The inserted row as a dictionary

Examples:
## Single primary key usage:

@pytest.fixture
async def user_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]:
user_data = random_user(name="test_user", email="[email protected]")
async with insert_and_get_row_lifespan(
asyncpg_engine,
table=users,
values=user_data,
pk_col=users.c.id,
) as row:
yield row

##Composite primary key usage:

@pytest.fixture
async def service_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]:
service_data = {"key": "simcore/services/comp/test", "version": "1.0.0", "name": "Test Service"}
async with insert_and_get_row_lifespan(
asyncpg_engine,
table=services,
values=service_data,
pk_cols=[services.c.key, services.c.version],
) as row:
yield row

##Multiple rows with single primary keys using AsyncExitStack:

@pytest.fixture
async def users_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[list[dict]]:
users_data = [
random_user(name="user1", email="[email protected]"),
random_user(name="user2", email="[email protected]"),
]

async with AsyncExitStack() as stack:
created_users = []
for user_data in users_data:
row = await stack.enter_async_context(
insert_and_get_row_lifespan(
asyncpg_engine,
table=users,
values=user_data,
pk_col=users.c.id,
)
)
created_users.append(row)

yield created_users

## Multiple rows with composite primary keys using AsyncExitStack:

@pytest.fixture
async def services_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[list[dict]]:
services_data = [
{"key": "simcore/services/comp/service1", "version": "1.0.0", "name": "Service 1"},
{"key": "simcore/services/comp/service2", "version": "2.0.0", "name": "Service 2"},
{"key": "simcore/services/comp/service1", "version": "2.0.0", "name": "Service 1 v2"},
]

async with AsyncExitStack() as stack:
created_services = []
for service_data in services_data:
row = await stack.enter_async_context(
insert_and_get_row_lifespan(
asyncpg_engine,
table=services,
values=service_data,
pk_cols=[services.c.key, services.c.version],
)
)
created_services.append(row)

yield created_services
"""
# SETUP: insert & get
async with sqlalchemy_async_engine.begin() as conn:
row = await _async_insert_and_get_row(
conn, table=table, values=values, pk_col=pk_col, pk_value=pk_value
conn,
table=table,
values=values,
pk_col=pk_col,
pk_value=pk_value,
pk_cols=pk_cols,
pk_values=pk_values,
)
# If pk_value was None, get it from the row for deletion later
if pk_value is None:
pk_value = getattr(row, pk_col.name)

# Get pk values for deletion
if pk_cols is not None:
if pk_values is None:
pk_values = [getattr(row, col.name) for col in pk_cols]
where_clause = sa.and_(
*[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
)
else:
if pk_value is None:
pk_value = getattr(row, pk_col.name)
where_clause = pk_col == pk_value

assert row

# NOTE: DO NO USE dict(row) since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
# pylint: disable=protected-access
yield row._asdict()

# TEAD-DOWN: delete row
# TEARDOWN: delete row
async with sqlalchemy_async_engine.begin() as conn:
await conn.execute(table.delete().where(pk_col == pk_value))
await conn.execute(table.delete().where(where_clause))


@contextmanager
Expand All @@ -161,23 +335,43 @@ def sync_insert_and_get_row_lifespan(
*,
table: sa.Table,
values: dict[str, Any],
pk_col: sa.Column,
pk_col: sa.Column | None = None,
pk_value: Any | None = None,
pk_cols: list[sa.Column] | None = None,
pk_values: list[Any] | None = None,
) -> Iterator[dict[str, Any]]:
"""sync version of insert_and_get_row_lifespan.

TIP: more convenient for **module-scope fixtures** that setup the
database tables before the app starts since it does not require an `event_loop`
fixture (which is funcition-scoped )
fixture (which is function-scoped)

Supports both single and composite primary keys using the same parameter patterns
as the async version.
"""
# SETUP: insert & get
with sqlalchemy_sync_engine.begin() as conn:
row = _sync_insert_and_get_row(
conn, table=table, values=values, pk_col=pk_col, pk_value=pk_value
conn,
table=table,
values=values,
pk_col=pk_col,
pk_value=pk_value,
pk_cols=pk_cols,
pk_values=pk_values,
)
# If pk_value was None, get it from the row for deletion later
if pk_value is None:
pk_value = getattr(row, pk_col.name)

# Get pk values for deletion
if pk_cols is not None:
if pk_values is None:
pk_values = [getattr(row, col.name) for col in pk_cols]
where_clause = sa.and_(
*[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
)
else:
if pk_value is None:
pk_value = getattr(row, pk_col.name)
where_clause = pk_col == pk_value

assert row

Expand All @@ -187,4 +381,4 @@ def sync_insert_and_get_row_lifespan(

# TEARDOWN: delete row
with sqlalchemy_sync_engine.begin() as conn:
conn.execute(table.delete().where(pk_col == pk_value))
conn.execute(table.delete().where(where_clause))
Loading
Loading