diff --git a/api/specs/web-server/_products.py b/api/specs/web-server/_products.py index 260ee88c5af4..a77f50f31937 100644 --- a/api/specs/web-server/_products.py +++ b/api/specs/web-server/_products.py @@ -8,7 +8,7 @@ from typing import Annotated from fastapi import APIRouter, Depends -from models_library.api_schemas_webserver.product import ( +from models_library.api_schemas_webserver.products import ( CreditPriceGet, InvitationGenerate, InvitationGenerated, @@ -17,7 +17,9 @@ ) from models_library.generics import Envelope from simcore_service_webserver._meta import API_VTAG -from simcore_service_webserver.products._rest_schemas import ProductsRequestParams +from simcore_service_webserver.products._controller.rest_schemas import ( + ProductsRequestParams, +) router = APIRouter( prefix=f"/{API_VTAG}", @@ -31,8 +33,7 @@ "/credits-price", response_model=Envelope[CreditPriceGet], ) -async def get_current_product_price(): - ... +async def get_current_product_price(): ... @router.get( @@ -43,16 +44,14 @@ async def get_current_product_price(): "po", ], ) -async def get_product(_params: Annotated[ProductsRequestParams, Depends()]): - ... +async def get_product(_params: Annotated[ProductsRequestParams, Depends()]): ... @router.get( "/products/current/ui", response_model=Envelope[ProductUIGet], ) -async def get_current_product_ui(): - ... +async def get_current_product_ui(): ... @router.post( @@ -62,5 +61,4 @@ async def get_current_product_ui(): "po", ], ) -async def generate_invitation(_body: InvitationGenerate): - ... +async def generate_invitation(_body: InvitationGenerate): ... diff --git a/packages/models-library/src/models_library/api_schemas_webserver/product.py b/packages/models-library/src/models_library/api_schemas_webserver/products.py similarity index 88% rename from packages/models-library/src/models_library/api_schemas_webserver/product.py rename to packages/models-library/src/models_library/api_schemas_webserver/products.py index 475361d8ca42..61f03a2c5e95 100644 --- a/packages/models-library/src/models_library/api_schemas_webserver/product.py +++ b/packages/models-library/src/models_library/api_schemas_webserver/products.py @@ -1,8 +1,10 @@ from datetime import datetime +from decimal import Decimal from typing import Annotated, Any, TypeAlias from common_library.basic_types import DEFAULT_FACTORY from pydantic import ( + BaseModel, ConfigDict, Field, HttpUrl, @@ -19,6 +21,28 @@ from ._base import InputSchema, OutputSchema +class CreditResultRpcGet(BaseModel): + product_name: ProductName + credit_amount: Decimal + + @staticmethod + def _update_json_schema_extra(schema: JsonDict) -> None: + schema.update( + { + "examples": [ + { + "product_name": "s4l", + "credit_amount": Decimal("15.5"), # type: ignore[dict-item] + }, + ] + } + ) + + model_config = ConfigDict( + json_schema_extra=_update_json_schema_extra, + ) + + class CreditPriceGet(OutputSchema): product_name: str usd_per_credit: Annotated[ diff --git a/packages/models-library/src/models_library/products.py b/packages/models-library/src/models_library/products.py index 51c44a83d478..d9f25a000f56 100644 --- a/packages/models-library/src/models_library/products.py +++ b/packages/models-library/src/models_library/products.py @@ -1,36 +1,5 @@ -from decimal import Decimal from typing import TypeAlias -from pydantic import BaseModel, ConfigDict, Field - ProductName: TypeAlias = str StripePriceID: TypeAlias = str StripeTaxRateID: TypeAlias = str - - -class CreditResultGet(BaseModel): - product_name: ProductName - credit_amount: Decimal = Field(..., description="") - - model_config = ConfigDict( - json_schema_extra={ - "examples": [ - {"product_name": "s4l", "credit_amount": Decimal(15.5)}, # type: ignore[dict-item] - ] - } - ) - - -class ProductStripeInfoGet(BaseModel): - stripe_price_id: StripePriceID - stripe_tax_rate_id: StripeTaxRateID - model_config = ConfigDict( - json_schema_extra={ - "examples": [ - { - "stripe_price_id": "stripe-price-id", - "stripe_tax_rate_id": "stripe-tax-rate-id", - }, - ] - } - ) diff --git a/packages/postgres-database/src/simcore_postgres_database/errors.py b/packages/postgres-database/src/simcore_postgres_database/aiopg_errors.py similarity index 53% rename from packages/postgres-database/src/simcore_postgres_database/errors.py rename to packages/postgres-database/src/simcore_postgres_database/aiopg_errors.py index 9c4fb417854c..730d6f630ac1 100644 --- a/packages/postgres-database/src/simcore_postgres_database/errors.py +++ b/packages/postgres-database/src/simcore_postgres_database/aiopg_errors.py @@ -1,25 +1,32 @@ -""" aiopg errors +"""aiopg errors - StandardError - |__ Warning - |__ Error - |__ InterfaceError - |__ DatabaseError - |__ DataError - |__ OperationalError - |__ IntegrityError - |__ InternalError - |__ ProgrammingError - |__ NotSupportedError +WARNING: these errors are not raised by asyncpg. Therefore all code using new sqlalchemy.ext.asyncio + MUST use instead import sqlalchemy.exc exceptions!!!! - - aiopg reuses DBAPI exceptions - SEE https://aiopg.readthedocs.io/en/stable/core.html?highlight=Exception#exceptions - SEE http://initd.org/psycopg/docs/module.html#dbapi-exceptions - SEE https://www.postgresql.org/docs/current/errcodes-appendix.html +StandardError +|__ Warning +|__ Error + |__ InterfaceError + |__ DatabaseError + |__ DataError + |__ OperationalError + |__ IntegrityError + |__ InternalError + |__ ProgrammingError + |__ NotSupportedError + +- aiopg reuses DBAPI exceptions + SEE https://aiopg.readthedocs.io/en/stable/core.html?highlight=Exception#exceptions + SEE http://initd.org/psycopg/docs/module.html#dbapi-exceptions + SEE https://www.postgresql.org/docs/current/errcodes-appendix.html """ + # NOTE: psycopg2.errors are created dynamically # pylint: disable=no-name-in-module -from psycopg2 import DatabaseError, DataError +from psycopg2 import ( + DatabaseError, + DataError, +) from psycopg2 import Error as DBAPIError from psycopg2 import ( IntegrityError, diff --git a/packages/postgres-database/src/simcore_postgres_database/utils_payments.py b/packages/postgres-database/src/simcore_postgres_database/utils_payments.py index 7202eb21d742..de4db3abe11b 100644 --- a/packages/postgres-database/src/simcore_postgres_database/utils_payments.py +++ b/packages/postgres-database/src/simcore_postgres_database/utils_payments.py @@ -8,7 +8,7 @@ from aiopg.sa.connection import SAConnection from aiopg.sa.result import ResultProxy, RowProxy -from . import errors +from . import aiopg_errors from .models.payments_transactions import PaymentTransactionState, payments_transactions _logger = logging.getLogger(__name__) @@ -29,16 +29,13 @@ def __bool__(self): return False -class PaymentAlreadyExists(PaymentFailure): - ... +class PaymentAlreadyExists(PaymentFailure): ... -class PaymentNotFound(PaymentFailure): - ... +class PaymentNotFound(PaymentFailure): ... -class PaymentAlreadyAcked(PaymentFailure): - ... +class PaymentAlreadyAcked(PaymentFailure): ... async def insert_init_payment_transaction( @@ -69,7 +66,7 @@ async def insert_init_payment_transaction( initiated_at=initiated_at, ) ) - except errors.UniqueViolation: + except aiopg_errors.UniqueViolation: return PaymentAlreadyExists(payment_id) return payment_id diff --git a/packages/postgres-database/src/simcore_postgres_database/utils_products.py b/packages/postgres-database/src/simcore_postgres_database/utils_products.py index 33e877c21d09..dba8caf074b3 100644 --- a/packages/postgres-database/src/simcore_postgres_database/utils_products.py +++ b/packages/postgres-database/src/simcore_postgres_database/utils_products.py @@ -1,12 +1,8 @@ -""" Common functions to access products table - -""" - -import warnings +"""Common functions to access products table""" import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection -from ._protocols import AiopgConnection, DBConnection from .models.groups import GroupType, groups from .models.products import products @@ -14,7 +10,10 @@ _GroupID = int -async def get_default_product_name(conn: DBConnection) -> str: +class EmptyProductsError(ValueError): ... + + +async def get_default_product_name(conn: AsyncConnection) -> str: """The first row in the table is considered as the default product :: raises ValueError if undefined @@ -23,15 +22,15 @@ async def get_default_product_name(conn: DBConnection) -> str: sa.select(products.c.name).order_by(products.c.priority) ) if not product_name: - msg = "No product defined in database" - raise ValueError(msg) + msg = "No product was defined in database. Upon construction, at least one product is added but there are none." + raise EmptyProductsError(msg) assert isinstance(product_name, str) # nosec return product_name -async def get_product_group_id( - connection: DBConnection, product_name: str +async def get_product_group_id_or_none( + connection: AsyncConnection, product_name: str ) -> _GroupID | None: group_id = await connection.scalar( sa.select(products.c.group_id).where(products.c.name == product_name) @@ -39,7 +38,9 @@ async def get_product_group_id( return None if group_id is None else _GroupID(group_id) -async def execute_get_or_create_product_group(conn, product_name: str) -> int: +async def get_or_create_product_group( + conn: AsyncConnection, product_name: str +) -> _GroupID: # # NOTE: Separated so it can be used in asyncpg and aiopg environs while both # coexist @@ -70,23 +71,3 @@ async def execute_get_or_create_product_group(conn, product_name: str) -> int: ) return group_id - - -async def get_or_create_product_group( - connection: AiopgConnection, product_name: str -) -> _GroupID: - """ - Returns group_id of a product. Creates it if undefined - """ - warnings.warn( - f"{__name__}.get_or_create_product_group uses aiopg which has been deprecated in this repo. Please use the asyncpg equivalent version instead" - "See https://github.com/ITISFoundation/osparc-simcore/issues/4529", - DeprecationWarning, - stacklevel=1, - ) - - async with connection.begin(): - group_id = await execute_get_or_create_product_group( - connection, product_name=product_name - ) - return _GroupID(group_id) diff --git a/packages/postgres-database/src/simcore_postgres_database/utils_products_prices.py b/packages/postgres-database/src/simcore_postgres_database/utils_products_prices.py index b573b78e4153..549bcd116e26 100644 --- a/packages/postgres-database/src/simcore_postgres_database/utils_products_prices.py +++ b/packages/postgres-database/src/simcore_postgres_database/utils_products_prices.py @@ -2,7 +2,7 @@ from typing import NamedTuple, TypeAlias import sqlalchemy as sa -from aiopg.sa.connection import SAConnection +from sqlalchemy.ext.asyncio import AsyncConnection from .constants import QUANTIZE_EXP_ARG from .models.products_prices import products_prices @@ -17,9 +17,9 @@ class ProductPriceInfo(NamedTuple): async def get_product_latest_price_info_or_none( - conn: SAConnection, product_name: str + conn: AsyncConnection, product_name: str ) -> ProductPriceInfo | None: - """None menans the product is not billable""" + """If the product is not billable, it returns None""" # newest price of a product result = await conn.execute( sa.select( @@ -30,7 +30,7 @@ async def get_product_latest_price_info_or_none( .order_by(sa.desc(products_prices.c.valid_from)) .limit(1) ) - row = await result.first() + row = result.one_or_none() if row and row.usd_per_credit is not None: assert row.min_payment_amount_usd is not None # nosec @@ -43,27 +43,24 @@ async def get_product_latest_price_info_or_none( return None -async def get_product_latest_stripe_info( - conn: SAConnection, product_name: str -) -> tuple[StripePriceID, StripeTaxRateID]: +async def get_product_latest_stripe_info_or_none( + conn: AsyncConnection, product_name: str +) -> tuple[StripePriceID, StripeTaxRateID] | None: # Stripe info of a product for latest price - row = await ( - await conn.execute( - sa.select( - products_prices.c.stripe_price_id, - products_prices.c.stripe_tax_rate_id, - ) - .where(products_prices.c.product_name == product_name) - .order_by(sa.desc(products_prices.c.valid_from)) - .limit(1) + result = await conn.execute( + sa.select( + products_prices.c.stripe_price_id, + products_prices.c.stripe_tax_rate_id, ) - ).fetchone() - if row is None: - msg = f"Required Stripe information missing from product {product_name=}" - raise ValueError(msg) - return (row.stripe_price_id, row.stripe_tax_rate_id) + .where(products_prices.c.product_name == product_name) + .order_by(sa.desc(products_prices.c.valid_from)) + .limit(1) + ) + + row = result.one_or_none() + return (row.stripe_price_id, row.stripe_tax_rate_id) if row else None -async def is_payment_enabled(conn: SAConnection, product_name: str) -> bool: +async def is_payment_enabled(conn: AsyncConnection, product_name: str) -> bool: p = await get_product_latest_price_info_or_none(conn, product_name=product_name) return bool(p) # zero or None is disabled diff --git a/packages/postgres-database/src/simcore_postgres_database/utils_projects_metadata.py b/packages/postgres-database/src/simcore_postgres_database/utils_projects_metadata.py index 149bb50b6a1c..9140dd5e43ee 100644 --- a/packages/postgres-database/src/simcore_postgres_database/utils_projects_metadata.py +++ b/packages/postgres-database/src/simcore_postgres_database/utils_projects_metadata.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict from sqlalchemy.dialects.postgresql import insert as pg_insert -from .errors import ForeignKeyViolation +from .aiopg_errors import ForeignKeyViolation from .models.projects import projects from .models.projects_metadata import projects_metadata @@ -33,11 +33,15 @@ class DBProjectInvalidAncestorsError(BaseProjectsMetadataError): class DBProjectInvalidParentProjectError(BaseProjectsMetadataError): - msg_template: str = "Project project_uuid={project_uuid!r} has invalid parent project uuid={parent_project_uuid!r}" + msg_template: str = ( + "Project project_uuid={project_uuid!r} has invalid parent project uuid={parent_project_uuid!r}" + ) class DBProjectInvalidParentNodeError(BaseProjectsMetadataError): - msg_template: str = "Project project_uuid={project_uuid!r} has invalid parent project uuid={parent_node_id!r}" + msg_template: str = ( + "Project project_uuid={project_uuid!r} has invalid parent project uuid={parent_node_id!r}" + ) # diff --git a/packages/postgres-database/src/simcore_postgres_database/utils_projects_nodes.py b/packages/postgres-database/src/simcore_postgres_database/utils_projects_nodes.py index 9cab49d27fa9..6ad87315183f 100644 --- a/packages/postgres-database/src/simcore_postgres_database/utils_projects_nodes.py +++ b/packages/postgres-database/src/simcore_postgres_database/utils_projects_nodes.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.dialects.postgresql import insert as pg_insert -from .errors import ForeignKeyViolation, UniqueViolation +from .aiopg_errors import ForeignKeyViolation, UniqueViolation from .models.projects_node_to_pricing_unit import projects_node_to_pricing_unit from .models.projects_nodes import projects_nodes @@ -30,11 +30,15 @@ class ProjectNodesNodeNotFoundError(BaseProjectNodesError): class ProjectNodesNonUniqueNodeFoundError(BaseProjectNodesError): - msg_template: str = "Multiple project found containing node {node_id}. TIP: misuse, the same node ID was found in several projects." + msg_template: str = ( + "Multiple project found containing node {node_id}. TIP: misuse, the same node ID was found in several projects." + ) class ProjectNodesDuplicateNodeError(BaseProjectNodesError): - msg_template: str = "Project node already exists, you cannot have 2x the same node in the same project." + msg_template: str = ( + "Project node already exists, you cannot have 2x the same node in the same project." + ) class ProjectNodeCreate(BaseModel): diff --git a/packages/postgres-database/src/simcore_postgres_database/utils_users.py b/packages/postgres-database/src/simcore_postgres_database/utils_users.py index bd85f7d44714..26a2684197af 100644 --- a/packages/postgres-database/src/simcore_postgres_database/utils_users.py +++ b/packages/postgres-database/src/simcore_postgres_database/utils_users.py @@ -12,7 +12,7 @@ from aiopg.sa.result import RowProxy from sqlalchemy import Column -from .errors import UniqueViolation +from .aiopg_errors import UniqueViolation from .models.users import UserRole, UserStatus, users from .models.users_details import users_pre_registration_details diff --git a/packages/postgres-database/tests/products/conftest.py b/packages/postgres-database/tests/products/conftest.py index eb3d213c2495..168ba260e186 100644 --- a/packages/postgres-database/tests/products/conftest.py +++ b/packages/postgres-database/tests/products/conftest.py @@ -7,7 +7,6 @@ from collections.abc import Callable import pytest -from aiopg.sa.exc import ResourceClosedError from faker import Faker from pytest_simcore.helpers.faker_factories import random_product from simcore_postgres_database.webserver_models import products @@ -15,7 +14,7 @@ @pytest.fixture -def products_regex() -> dict: +def products_regex() -> dict[str, str]: return { "s4l": r"(^s4l[\.-])|(^sim4life\.)", "osparc": r"^osparc.", @@ -24,12 +23,12 @@ def products_regex() -> dict: @pytest.fixture -def products_names(products_regex: dict) -> list[str]: +def products_names(products_regex: dict[str, str]) -> list[str]: return list(products_regex) @pytest.fixture -def make_products_table(products_regex: dict, faker: Faker) -> Callable: +def make_products_table(products_regex: dict[str, str], faker: Faker) -> Callable: async def _make(conn) -> None: for n, (name, regex) in enumerate(products_regex.items()): @@ -37,6 +36,7 @@ async def _make(conn) -> None: pg_insert(products) .values( **random_product( + fake=faker, name=name, display_name=f"Product {name.capitalize()}", short_name=name[:3].lower(), @@ -45,6 +45,7 @@ async def _make(conn) -> None: ) ) .on_conflict_do_update( + # osparc might be already injected as default! index_elements=[products.c.name], set_={ "display_name": f"Product {name.capitalize()}", @@ -55,9 +56,7 @@ async def _make(conn) -> None: ) ) - assert result.closed + assert not result.closed assert not result.returns_rows - with pytest.raises(ResourceClosedError): - await result.scalar() return _make diff --git a/packages/postgres-database/tests/products/test_models_products.py b/packages/postgres-database/tests/products/test_models_products.py index c385cd7e7340..1f34fab7aa49 100644 --- a/packages/postgres-database/tests/products/test_models_products.py +++ b/packages/postgres-database/tests/products/test_models_products.py @@ -3,15 +3,10 @@ # pylint: disable=redefined-outer-name # pylint: disable=unused-argument - -import json from collections.abc import Callable from pathlib import Path -from pprint import pprint import sqlalchemy as sa -from aiopg.sa.engine import Engine -from aiopg.sa.result import ResultProxy, RowProxy from simcore_postgres_database.models.jinja2_templates import jinja2_templates from simcore_postgres_database.models.products import ( EmailFeedback, @@ -23,40 +18,37 @@ ) from simcore_postgres_database.webserver_models import products from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncEngine async def test_load_products( - aiopg_engine: Engine, make_products_table: Callable, products_regex: dict + asyncpg_engine: AsyncEngine, make_products_table: Callable, products_regex: dict ): exclude = { products.c.created, products.c.modified, } - async with aiopg_engine.acquire() as conn: + async with asyncpg_engine.connect() as conn: await make_products_table(conn) stmt = sa.select(*[c for c in products.columns if c not in exclude]) - result: ResultProxy = await conn.execute(stmt) - assert result.returns_rows - - rows: list[RowProxy] = await result.fetchall() + result = await conn.execute(stmt) + rows = result.fetchall() assert rows - assert { - row[products.c.name]: row[products.c.host_regex] for row in rows - } == products_regex + assert {row.name: row.host_regex for row in rows} == products_regex async def test_jinja2_templates_table( - aiopg_engine: Engine, osparc_simcore_services_dir: Path + asyncpg_engine: AsyncEngine, osparc_simcore_services_dir: Path ): templates_common_dir = ( osparc_simcore_services_dir / "web/server/src/simcore_service_webserver/templates/common" ) - async with aiopg_engine.acquire() as conn: + async with asyncpg_engine.connect() as conn: templates = [] # templates table for p in templates_common_dir.glob("*.jinja2"): @@ -105,10 +97,9 @@ async def test_jinja2_templates_table( products.c.name, jinja2_templates.c.name, products.c.short_name ).select_from(j) - result: ResultProxy = await conn.execute(stmt) - assert result.rowcount == 2 - rows = await result.fetchall() - assert sorted(r.as_tuple() for r in rows) == sorted( + result = await conn.execute(stmt) + rows = result.fetchall() + assert sorted(tuple(r) for r in rows) == sorted( [ ("osparc", "registration_email.jinja2", "osparc"), ("s4l", "registration_email.jinja2", "s4l web"), @@ -135,7 +126,7 @@ async def test_jinja2_templates_table( async def test_insert_select_product( - aiopg_engine: Engine, + asyncpg_engine: AsyncEngine, ): osparc_product = { "name": "osparc", @@ -172,9 +163,7 @@ async def test_insert_select_product( ], } - print(json.dumps(osparc_product)) - - async with aiopg_engine.acquire() as conn: + async with asyncpg_engine.begin() as conn: # writes stmt = ( pg_insert(products) @@ -188,12 +177,9 @@ async def test_insert_select_product( # reads stmt = sa.select(products).where(products.c.name == name) - row = await (await conn.execute(stmt)).fetchone() - print(row) + row = (await conn.execute(stmt)).one_or_none() assert row - pprint(dict(**row)) - assert row.manuals assert row.manuals == osparc_product["manuals"] diff --git a/packages/postgres-database/tests/products/test_products_to_templates.py b/packages/postgres-database/tests/products/test_products_to_templates.py index b1245b597d8d..9a78aaba94c0 100644 --- a/packages/postgres-database/tests/products/test_products_to_templates.py +++ b/packages/postgres-database/tests/products/test_products_to_templates.py @@ -10,12 +10,12 @@ import pytest import sqlalchemy as sa -from aiopg.sa.connection import SAConnection from faker import Faker from simcore_postgres_database.models.jinja2_templates import jinja2_templates from simcore_postgres_database.models.products import products from simcore_postgres_database.models.products_to_templates import products_to_templates from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncEngine @pytest.fixture @@ -48,54 +48,58 @@ def templates_dir( @pytest.fixture async def product_templates_in_db( - connection: SAConnection, + asyncpg_engine: AsyncEngine, make_products_table: Callable, products_names: list[str], templates_names: list[str], ): - await make_products_table(connection) - - # one version of all tempaltes - for template_name in templates_names: - await connection.execute( - jinja2_templates.insert().values( - name=template_name, content="fake template in database" + async with asyncpg_engine.begin() as conn: + await make_products_table(conn) + + # one version of all tempaltes + for template_name in templates_names: + await conn.execute( + jinja2_templates.insert().values( + name=template_name, content="fake template in database" + ) ) - ) - # only even products have templates - for product_name in products_names[0::2]: - await connection.execute( - products_to_templates.insert().values( - template_name=template_name, product_name=product_name + # only even products have templates + for product_name in products_names[0::2]: + await conn.execute( + products_to_templates.insert().values( + template_name=template_name, product_name=product_name + ) ) - ) async def test_export_and_import_table( - connection: SAConnection, + asyncpg_engine: AsyncEngine, product_templates_in_db: None, ): - exported_values = [] - excluded_names = {"created", "modified", "group_id"} - async for row in connection.execute( - sa.select(*(c for c in products.c if c.name not in excluded_names)) - ): - assert row - exported_values.append(dict(row)) - - # now just upsert them - for values in exported_values: - values["display_name"] += "-changed" - await connection.execute( - pg_insert(products) - .values(**values) - .on_conflict_do_update(index_elements=[products.c.name], set_=values) + + async with asyncpg_engine.connect() as connection: + exported_values = [] + excluded_names = {"created", "modified", "group_id"} + result = await connection.stream( + sa.select(*(c for c in products.c if c.name not in excluded_names)) ) + async for row in result: + assert row + exported_values.append(row._asdict()) + + # now just upsert them + for values in exported_values: + values["display_name"] += "-changed" + await connection.execute( + pg_insert(products) + .values(**values) + .on_conflict_do_update(index_elements=[products.c.name], set_=values) + ) async def test_create_templates_products_folder( - connection: SAConnection, + asyncpg_engine: AsyncEngine, templates_dir: Path, products_names: list[str], tmp_path: Path, @@ -121,20 +125,22 @@ async def test_create_templates_products_folder( shutil.copy(p, product_folder / p.name, follow_symlinks=False) # overrides if with files in database - async for row in connection.execute( - sa.select( - products_to_templates.c.product_name, - jinja2_templates.c.name, - jinja2_templates.c.content, + async with asyncpg_engine.connect() as conn: + result = await conn.stream( + sa.select( + products_to_templates.c.product_name, + jinja2_templates.c.name, + jinja2_templates.c.content, + ) + .select_from(products_to_templates.join(jinja2_templates)) + .where(products_to_templates.c.product_name == product_name) ) - .select_from(products_to_templates.join(jinja2_templates)) - .where(products_to_templates.c.product_name == product_name) - ): - assert row - template_path = product_folder / row.name - template_path.write_text(row.content) + async for row in result: + assert row + template_path = product_folder / row.name + template_path.write_text(row.content) - assert sorted( - product_folder / template_name for template_name in templates_names - ) == sorted(product_folder.rglob("*.*")) + assert sorted( + product_folder / template_name for template_name in templates_names + ) == sorted(product_folder.rglob("*.*")) diff --git a/packages/postgres-database/tests/products/test_utils_products.py b/packages/postgres-database/tests/products/test_utils_products.py index a1b84fe96dd8..b25ffbc0ccfe 100644 --- a/packages/postgres-database/tests/products/test_utils_products.py +++ b/packages/postgres-database/tests/products/test_utils_products.py @@ -3,43 +3,44 @@ # pylint: disable=redefined-outer-name # pylint: disable=unused-argument - -import asyncio from collections.abc import Callable import pytest import sqlalchemy as sa -from aiopg.sa.engine import Engine from simcore_postgres_database.models.groups import GroupType, groups from simcore_postgres_database.models.products import products from simcore_postgres_database.utils_products import ( + EmptyProductsError, get_default_product_name, get_or_create_product_group, - get_product_group_id, + get_product_group_id_or_none, ) +from sqlalchemy.ext.asyncio import AsyncEngine -async def test_default_product(aiopg_engine: Engine, make_products_table: Callable): - async with aiopg_engine.acquire() as conn: +async def test_default_product( + asyncpg_engine: AsyncEngine, make_products_table: Callable +): + async with asyncpg_engine.begin() as conn: await make_products_table(conn) default_product = await get_default_product_name(conn) assert default_product == "s4l" @pytest.mark.parametrize("pg_sa_engine", ["sqlModels"], indirect=True) -async def test_default_product_undefined(aiopg_engine: Engine): - async with aiopg_engine.acquire() as conn: - with pytest.raises(ValueError): +async def test_default_product_undefined(asyncpg_engine: AsyncEngine): + async with asyncpg_engine.connect() as conn: + with pytest.raises(EmptyProductsError): await get_default_product_name(conn) async def test_get_or_create_group_product( - aiopg_engine: Engine, make_products_table: Callable + asyncpg_engine: AsyncEngine, make_products_table: Callable ): - async with aiopg_engine.acquire() as conn: + async with asyncpg_engine.connect() as conn: await make_products_table(conn) - async for product_row in await conn.execute( + async for product_row in await conn.stream( sa.select(products.c.name, products.c.group_id).order_by( products.c.priority ) @@ -57,8 +58,7 @@ async def test_get_or_create_group_product( result = await conn.execute( groups.select().where(groups.c.gid == product_group_id) ) - assert result.rowcount == 1 - product_group = await result.first() + product_group = result.one() # check product's group assert product_group.type == GroupType.STANDARD @@ -78,9 +78,9 @@ async def test_get_or_create_group_product( result = await conn.execute( groups.select().where(groups.c.name == product_row.name) ) - assert result.rowcount == 1 + assert result.one() - assert product_group_id == await get_product_group_id( + assert product_group_id == await get_product_group_id_or_none( conn, product_name=product_row.name ) @@ -88,43 +88,14 @@ async def test_get_or_create_group_product( await conn.execute( groups.update().where(groups.c.gid == product_group_id).values(gid=1000) ) - product_group_id = await get_product_group_id( + product_group_id = await get_product_group_id_or_none( conn, product_name=product_row.name ) assert product_group_id == 1000 # if group is DELETED -> product.group_id=null await conn.execute(groups.delete().where(groups.c.gid == product_group_id)) - product_group_id = await get_product_group_id( + product_group_id = await get_product_group_id_or_none( conn, product_name=product_row.name ) assert product_group_id is None - - -@pytest.mark.skip( - reason="Not relevant. Will review in https://github.com/ITISFoundation/osparc-simcore/issues/3754" -) -async def test_get_or_create_group_product_concurrent( - aiopg_engine: Engine, make_products_table: Callable -): - async with aiopg_engine.acquire() as conn: - await make_products_table(conn) - - async def _auto_create_products_groups(): - async with aiopg_engine.acquire() as conn: - async for product_row in await conn.execute( - sa.select(products.c.name, products.c.group_id).order_by( - products.c.priority - ) - ): - # get or create - return await get_or_create_product_group( - conn, product_name=product_row.name - ) - return None - - tasks = [asyncio.create_task(_auto_create_products_groups()) for _ in range(5)] - - results = await asyncio.gather(*tasks) - - assert all(res == results[0] for res in results[1:]) diff --git a/packages/postgres-database/tests/test_clusters.py b/packages/postgres-database/tests/test_clusters.py index 95cd8492965f..7e643bb0fb9f 100644 --- a/packages/postgres-database/tests/test_clusters.py +++ b/packages/postgres-database/tests/test_clusters.py @@ -9,7 +9,7 @@ from aiopg.sa.engine import Engine from aiopg.sa.result import ResultProxy from pytest_simcore.helpers.faker_factories import random_user -from simcore_postgres_database.errors import ForeignKeyViolation, NotNullViolation +from simcore_postgres_database.aiopg_errors import ForeignKeyViolation, NotNullViolation from simcore_postgres_database.models.cluster_to_groups import cluster_to_groups from simcore_postgres_database.models.clusters import ClusterType, clusters from simcore_postgres_database.models.users import users @@ -41,7 +41,7 @@ async def user_group_id(aiopg_engine: Engine, user_id: int) -> int: async def test_cluster_without_owner_forbidden( - create_fake_cluster: Callable[..., Awaitable[int]] + create_fake_cluster: Callable[..., Awaitable[int]], ): with pytest.raises(NotNullViolation): await create_fake_cluster() diff --git a/packages/postgres-database/tests/test_models_payments_methods.py b/packages/postgres-database/tests/test_models_payments_methods.py index 100c0e5431b1..cb5b14ee70e4 100644 --- a/packages/postgres-database/tests/test_models_payments_methods.py +++ b/packages/postgres-database/tests/test_models_payments_methods.py @@ -10,7 +10,7 @@ from aiopg.sa.result import RowProxy from faker import Faker from pytest_simcore.helpers.faker_factories import random_payment_method -from simcore_postgres_database.errors import UniqueViolation +from simcore_postgres_database.aiopg_errors import UniqueViolation from simcore_postgres_database.models.payments_methods import ( InitPromptAckFlowState, payments_methods, diff --git a/packages/postgres-database/tests/test_models_products_prices.py b/packages/postgres-database/tests/test_models_products_prices.py index 7112f31b612b..406158af0bf9 100644 --- a/packages/postgres-database/tests/test_models_products_prices.py +++ b/packages/postgres-database/tests/test_models_products_prices.py @@ -4,58 +4,102 @@ # pylint: disable=too-many-arguments +from collections.abc import AsyncIterator + import pytest import sqlalchemy as sa -from aiopg.sa.connection import SAConnection -from aiopg.sa.result import RowProxy +import sqlalchemy.exc from faker import Faker from pytest_simcore.helpers.faker_factories import random_product -from simcore_postgres_database.errors import CheckViolation, ForeignKeyViolation from simcore_postgres_database.models.products import products from simcore_postgres_database.models.products_prices import products_prices from simcore_postgres_database.utils_products_prices import ( get_product_latest_price_info_or_none, - get_product_latest_stripe_info, + get_product_latest_stripe_info_or_none, is_payment_enabled, ) +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine + + +@pytest.fixture +async def connection(asyncpg_engine: AsyncEngine) -> AsyncIterator[AsyncConnection]: + async with asyncpg_engine.connect() as conn: + isolation_level = await conn.get_isolation_level() + assert isolation_level == "READ COMMITTED" + yield conn @pytest.fixture -async def fake_product(connection: SAConnection) -> RowProxy: +async def fake_product(connection: AsyncConnection) -> Row: result = await connection.execute( products.insert() - .values(random_product(group_id=None)) - .returning(sa.literal_column("*")) + .values(random_product(name="tip", group_id=None)) + .returning(sa.literal_column("*")), ) - product = await result.first() - assert product is not None - return product + await connection.commit() + + async with connection.begin(): + result = await connection.execute( + products.insert() + .values(random_product(name="s4l", group_id=None)) + .returning(sa.literal_column("*")), + ) + + return result.one() async def test_creating_product_prices( - connection: SAConnection, fake_product: RowProxy, faker: Faker + asyncpg_engine: AsyncEngine, + connection: AsyncConnection, + fake_product: Row, + faker: Faker, ): # a price per product - result = await connection.execute( - products_prices.insert() - .values( - product_name=fake_product.name, - usd_per_credit=100, - comment="PO Mr X", - stripe_price_id=faker.word(), - stripe_tax_rate_id=faker.word(), + async with connection.begin(): + result = await connection.execute( + products_prices.insert() + .values( + product_name=fake_product.name, + usd_per_credit=100, + comment="PO Mr X", + stripe_price_id=faker.word(), + stripe_tax_rate_id=faker.word(), + ) + .returning(sa.literal_column("*")), ) - .returning(sa.literal_column("*")) - ) - product_prices = await result.first() - assert product_prices + got = result.one() + assert got + + # insert still NOT commited but can read from this connection + read_query = sa.select(products_prices).where( + products_prices.c.product_name == fake_product.name + ) + result = await connection.execute(read_query) + assert result.one()._asdict() == got._asdict() + + assert connection.in_transaction() is True + + # cannot read from other connection though + async with asyncpg_engine.connect() as other_connection: + result = await other_connection.execute(read_query) + assert result.one_or_none() is None + + # AFTER commit + assert connection.in_transaction() is False + async with asyncpg_engine.connect() as yet_another_connection: + result = await yet_another_connection.execute(read_query) + assert result.one()._asdict() == got._asdict() async def test_non_negative_price_not_allowed( - connection: SAConnection, fake_product: RowProxy, faker: Faker + connection: AsyncConnection, fake_product: Row, faker: Faker ): - # negative price not allowed - with pytest.raises(CheckViolation) as exc_info: + + assert not connection.in_transaction() + + # WRITE: negative price not allowed + with pytest.raises(sqlalchemy.exc.IntegrityError) as exc_info: await connection.execute( products_prices.insert().values( product_name=fake_product.name, @@ -67,46 +111,76 @@ async def test_non_negative_price_not_allowed( ) assert exc_info.value + assert connection.in_transaction() + await connection.rollback() + assert not connection.in_transaction() - # zero price is allowed - await connection.execute( - products_prices.insert().values( + # WRITE: zero price is allowed + result = await connection.execute( + products_prices.insert() + .values( product_name=fake_product.name, usd_per_credit=0, # <----- ZERO comment="PO Mr X", stripe_price_id=faker.word(), stripe_tax_rate_id=faker.word(), ) + .returning("*") ) + assert result.one() + + assert connection.in_transaction() + await connection.commit() + assert not connection.in_transaction() + + with pytest.raises(sqlalchemy.exc.ResourceClosedError): + # can only get result once! + assert result.one() + + # READ + result = await connection.execute(sa.select(products_prices)) + assert connection.in_transaction() + + assert result.one() + with pytest.raises(sqlalchemy.exc.ResourceClosedError): + # can only get result once! + assert result.one() + async def test_delete_price_constraints( - connection: SAConnection, fake_product: RowProxy, faker: Faker + connection: AsyncConnection, fake_product: Row, faker: Faker ): # products_prices - await connection.execute( - products_prices.insert().values( - product_name=fake_product.name, - usd_per_credit=10, - comment="PO Mr X", - stripe_price_id=faker.word(), - stripe_tax_rate_id=faker.word(), + async with connection.begin(): + await connection.execute( + products_prices.insert().values( + product_name=fake_product.name, + usd_per_credit=10, + comment="PO Mr X", + stripe_price_id=faker.word(), + stripe_tax_rate_id=faker.word(), + ) ) - ) + # BAD DELETE: # should not be able to delete a product w/o deleting price first - with pytest.raises(ForeignKeyViolation) as exc_info: - await connection.execute(products.delete()) + async with connection.begin(): + with pytest.raises(sqlalchemy.exc.IntegrityError, match="delete") as exc_info: + await connection.execute(products.delete()) - assert exc_info.match("delete") + # NOTE: that asyncpg.exceptions are converted to sqlalchemy.exc + # sqlalchemy.exc.IntegrityError: (sqlalchemy.dialects.postgresql.asyncpg.IntegrityError) : + assert "asyncpg.exceptions.ForeignKeyViolationError" in exc_info.value.args[0] - # this is the correct way to delete - await connection.execute(products_prices.delete()) - await connection.execute(products.delete()) + # GOOD DELETE: this is the correct way to delete + async with connection.begin(): + await connection.execute(products_prices.delete()) + await connection.execute(products.delete()) async def test_get_product_latest_price_or_none( - connection: SAConnection, fake_product: RowProxy, faker: Faker + connection: AsyncConnection, fake_product: Row, faker: Faker ): # undefined product assert ( @@ -130,29 +204,31 @@ async def test_get_product_latest_price_or_none( async def test_price_history_of_a_product( - connection: SAConnection, fake_product: RowProxy, faker: Faker + connection: AsyncConnection, fake_product: Row, faker: Faker ): # initial price - await connection.execute( - products_prices.insert().values( - product_name=fake_product.name, - usd_per_credit=1, - comment="PO Mr X", - stripe_price_id=faker.word(), - stripe_tax_rate_id=faker.word(), + async with connection.begin(): + await connection.execute( + products_prices.insert().values( + product_name=fake_product.name, + usd_per_credit=1, + comment="PO Mr X", + stripe_price_id=faker.word(), + stripe_tax_rate_id=faker.word(), + ) ) - ) # new price - await connection.execute( - products_prices.insert().values( - product_name=fake_product.name, - usd_per_credit=2, - comment="Update by Mr X", - stripe_price_id=faker.word(), - stripe_tax_rate_id=faker.word(), + async with connection.begin(): + await connection.execute( + products_prices.insert().values( + product_name=fake_product.name, + usd_per_credit=2, + comment="Update by Mr X", + stripe_price_id=faker.word(), + stripe_tax_rate_id=faker.word(), + ) ) - ) # latest is 2 USD! assert await get_product_latest_price_info_or_none( @@ -163,29 +239,33 @@ async def test_price_history_of_a_product( async def test_get_product_latest_stripe_info( - connection: SAConnection, fake_product: RowProxy, faker: Faker + connection: AsyncConnection, fake_product: Row, faker: Faker ): stripe_price_id_value = faker.word() stripe_tax_rate_id_value = faker.word() # products_prices - await connection.execute( - products_prices.insert().values( - product_name=fake_product.name, - usd_per_credit=10, - comment="PO Mr X", - stripe_price_id=stripe_price_id_value, - stripe_tax_rate_id=stripe_tax_rate_id_value, + async with connection.begin(): + await connection.execute( + products_prices.insert().values( + product_name=fake_product.name, + usd_per_credit=10, + comment="PO Mr X", + stripe_price_id=stripe_price_id_value, + stripe_tax_rate_id=stripe_tax_rate_id_value, + ) ) + + # undefined product + undefined_product_stripe_info = await get_product_latest_stripe_info_or_none( + connection, product_name="undefined" ) + assert undefined_product_stripe_info is None # defined product - product_stripe_info = await get_product_latest_stripe_info( + product_stripe_info = await get_product_latest_stripe_info_or_none( connection, product_name=fake_product.name ) + assert product_stripe_info assert product_stripe_info[0] == stripe_price_id_value assert product_stripe_info[1] == stripe_tax_rate_id_value - - # undefined product - with pytest.raises(ValueError) as exc_info: - await get_product_latest_stripe_info(connection, product_name="undefined") diff --git a/packages/postgres-database/tests/test_services_consume_filetypes.py b/packages/postgres-database/tests/test_services_consume_filetypes.py index f72799299073..efe0a083c6fc 100644 --- a/packages/postgres-database/tests/test_services_consume_filetypes.py +++ b/packages/postgres-database/tests/test_services_consume_filetypes.py @@ -15,7 +15,7 @@ FAKE_FILE_CONSUMER_SERVICES, list_supported_filetypes, ) -from simcore_postgres_database.errors import CheckViolation +from simcore_postgres_database.aiopg_errors import CheckViolation from simcore_postgres_database.models.services import services_meta_data from simcore_postgres_database.models.services_consume_filetypes import ( services_consume_filetypes, diff --git a/packages/postgres-database/tests/test_users.py b/packages/postgres-database/tests/test_users.py index 1c10636e7721..e759de7ec97e 100644 --- a/packages/postgres-database/tests/test_users.py +++ b/packages/postgres-database/tests/test_users.py @@ -11,7 +11,10 @@ from aiopg.sa.result import ResultProxy, RowProxy from faker import Faker from pytest_simcore.helpers.faker_factories import random_user -from simcore_postgres_database.errors import InvalidTextRepresentation, UniqueViolation +from simcore_postgres_database.aiopg_errors import ( + InvalidTextRepresentation, + UniqueViolation, +) from simcore_postgres_database.models.users import UserRole, UserStatus, users from simcore_postgres_database.utils_users import ( UsersRepo, diff --git a/services/api-server/src/simcore_service_api_server/db/repositories/api_keys.py b/services/api-server/src/simcore_service_api_server/db/repositories/api_keys.py index e970f819ccce..400b4e8b1c14 100644 --- a/services/api-server/src/simcore_service_api_server/db/repositories/api_keys.py +++ b/services/api-server/src/simcore_service_api_server/db/repositories/api_keys.py @@ -4,7 +4,7 @@ import sqlalchemy as sa from models_library.products import ProductName from pydantic.types import PositiveInt -from simcore_postgres_database.errors import DatabaseError +from simcore_postgres_database.aiopg_errors import DatabaseError from .. import tables as tbl from ._base import BaseRepository diff --git a/services/api-server/src/simcore_service_api_server/db/repositories/groups_extra_properties.py b/services/api-server/src/simcore_service_api_server/db/repositories/groups_extra_properties.py index 8193201baeec..847ae6926740 100644 --- a/services/api-server/src/simcore_service_api_server/db/repositories/groups_extra_properties.py +++ b/services/api-server/src/simcore_service_api_server/db/repositories/groups_extra_properties.py @@ -1,7 +1,7 @@ import logging from models_library.users import UserID -from simcore_postgres_database.errors import DatabaseError +from simcore_postgres_database.aiopg_errors import DatabaseError from simcore_postgres_database.utils_groups_extra_properties import ( GroupExtraPropertiesRepo, ) diff --git a/services/api-server/src/simcore_service_api_server/models/schemas/model_adapter.py b/services/api-server/src/simcore_service_api_server/models/schemas/model_adapter.py index a395d44be70b..8e0bcb3cba10 100644 --- a/services/api-server/src/simcore_service_api_server/models/schemas/model_adapter.py +++ b/services/api-server/src/simcore_service_api_server/models/schemas/model_adapter.py @@ -22,7 +22,7 @@ from models_library.api_schemas_webserver.licensed_items_checkouts import ( LicensedItemCheckoutRpcGet as _LicensedItemCheckoutRpcGet, ) -from models_library.api_schemas_webserver.product import ( +from models_library.api_schemas_webserver.products import ( CreditPriceGet as _GetCreditPrice, ) from models_library.api_schemas_webserver.resource_usage import ( diff --git a/services/api-server/tests/unit/test_credits.py b/services/api-server/tests/unit/test_credits.py index 4f7dd8b41e5e..f9548949b818 100644 --- a/services/api-server/tests/unit/test_credits.py +++ b/services/api-server/tests/unit/test_credits.py @@ -2,7 +2,7 @@ from fastapi import status from httpx import AsyncClient, BasicAuth -from models_library.api_schemas_webserver.product import CreditPriceGet +from models_library.api_schemas_webserver.products import CreditPriceGet from pytest_simcore.helpers.httpx_calls_capture_models import CreateRespxMockCallback from simcore_service_api_server._meta import API_VTAG diff --git a/services/catalog/src/simcore_service_catalog/db/repositories/products.py b/services/catalog/src/simcore_service_catalog/db/repositories/products.py index 57b036150d21..ea59f9dab05e 100644 --- a/services/catalog/src/simcore_service_catalog/db/repositories/products.py +++ b/services/catalog/src/simcore_service_catalog/db/repositories/products.py @@ -5,6 +5,6 @@ class ProductsRepository(BaseRepository): async def get_default_product_name(self) -> str: - async with self.db_engine.begin() as conn: + async with self.db_engine.connect() as conn: product_name: str = await get_default_product_name(conn) return product_name diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py index 46cc7669cde2..eac76d5d945a 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py @@ -10,7 +10,7 @@ from models_library.users import UserID from models_library.utils.fastapi_encoders import jsonable_encoder from pydantic import PositiveInt -from simcore_postgres_database.errors import ForeignKeyViolation +from simcore_postgres_database.aiopg_errors import ForeignKeyViolation from sqlalchemy.sql import or_ from sqlalchemy.sql.elements import literal_column from sqlalchemy.sql.expression import desc diff --git a/services/payments/src/simcore_service_payments/db/payments_methods_repo.py b/services/payments/src/simcore_service_payments/db/payments_methods_repo.py index 4eb43b667b13..cea7b8e6158e 100644 --- a/services/payments/src/simcore_service_payments/db/payments_methods_repo.py +++ b/services/payments/src/simcore_service_payments/db/payments_methods_repo.py @@ -1,6 +1,6 @@ import datetime -import simcore_postgres_database.errors as db_errors +import simcore_postgres_database.aiopg_errors as db_errors import sqlalchemy as sa from arrow import utcnow from models_library.api_schemas_payments.errors import ( diff --git a/services/payments/src/simcore_service_payments/db/payments_transactions_repo.py b/services/payments/src/simcore_service_payments/db/payments_transactions_repo.py index 8b2eef6f2286..d7f6b893668e 100644 --- a/services/payments/src/simcore_service_payments/db/payments_transactions_repo.py +++ b/services/payments/src/simcore_service_payments/db/payments_transactions_repo.py @@ -13,7 +13,7 @@ from models_library.users import UserID from models_library.wallets import WalletID from pydantic import HttpUrl, PositiveInt, TypeAdapter -from simcore_postgres_database import errors as pg_errors +from simcore_postgres_database import aiopg_errors as pg_errors from simcore_postgres_database.models.payments_transactions import ( PaymentTransactionState, payments_transactions, diff --git a/services/web/server/src/simcore_service_webserver/api_keys/_controller_rest.py b/services/web/server/src/simcore_service_webserver/api_keys/_controller_rest.py index c3b81c63cd31..963b00bdf706 100644 --- a/services/web/server/src/simcore_service_webserver/api_keys/_controller_rest.py +++ b/services/web/server/src/simcore_service_webserver/api_keys/_controller_rest.py @@ -24,7 +24,7 @@ from ..utils_aiohttp import envelope_json_response from . import _service from ._exceptions_handlers import handle_plugin_requests_exceptions -from ._models import ApiKey +from .models import ApiKey _logger = logging.getLogger(__name__) diff --git a/services/web/server/src/simcore_service_webserver/api_keys/_controller_rpc.py b/services/web/server/src/simcore_service_webserver/api_keys/_controller_rpc.py index 3dd04600cfad..59c91c3be59b 100644 --- a/services/web/server/src/simcore_service_webserver/api_keys/_controller_rpc.py +++ b/services/web/server/src/simcore_service_webserver/api_keys/_controller_rpc.py @@ -10,8 +10,8 @@ from ..rabbitmq import get_rabbitmq_rpc_server from . import _service -from ._models import ApiKey from .errors import ApiKeyNotFoundError +from .models import ApiKey router = RPCRouter() diff --git a/services/web/server/src/simcore_service_webserver/api_keys/_repository.py b/services/web/server/src/simcore_service_webserver/api_keys/_repository.py index 0002ceff1e57..bba0097403d9 100644 --- a/services/web/server/src/simcore_service_webserver/api_keys/_repository.py +++ b/services/web/server/src/simcore_service_webserver/api_keys/_repository.py @@ -7,13 +7,16 @@ from models_library.products import ProductName from models_library.users import UserID from simcore_postgres_database.models.api_keys import api_keys -from simcore_postgres_database.utils_repos import transaction_context +from simcore_postgres_database.utils_repos import ( + pass_or_acquire_connection, + transaction_context, +) from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncConnection from ..db.plugin import get_asyncpg_engine -from ._models import ApiKey from .errors import ApiKeyDuplicatedDisplayNameError +from .models import ApiKey _logger = logging.getLogger(__name__) @@ -45,7 +48,7 @@ async def create_api_key( ) result = await conn.stream(stmt) - row = await result.first() + row = await result.one() return ApiKey( id=f"{row.id}", # NOTE See: https://github.com/ITISFoundation/osparc-simcore/issues/6919 @@ -111,7 +114,7 @@ async def list_api_keys( user_id: UserID, product_name: ProductName, ) -> list[ApiKey]: - async with transaction_context(get_asyncpg_engine(app), connection) as conn: + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: stmt = sa.select(api_keys.c.id, api_keys.c.display_name).where( (api_keys.c.user_id == user_id) & (api_keys.c.product_name == product_name) ) @@ -136,7 +139,7 @@ async def get_api_key( user_id: UserID, product_name: ProductName, ) -> ApiKey | None: - async with transaction_context(get_asyncpg_engine(app), connection) as conn: + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: stmt = sa.select(api_keys).where( ( api_keys.c.id == int(api_key_id) @@ -145,8 +148,8 @@ async def get_api_key( & (api_keys.c.product_name == product_name) ) - result = await conn.stream(stmt) - row = await result.first() + result = await conn.execute(stmt) + row = result.one_or_none() return ( ApiKey( @@ -180,7 +183,7 @@ async def delete_api_key( await conn.execute(stmt) -async def prune_expired( +async def delete_expired_api_keys( app: web.Application, connection: AsyncConnection | None = None ) -> list[str]: async with transaction_context(get_asyncpg_engine(app), connection) as conn: @@ -192,6 +195,6 @@ async def prune_expired( ) .returning(api_keys.c.display_name) ) - result = await conn.stream(stmt) - rows = [row async for row in result] + result = await conn.execute(stmt) + rows = result.fetchall() return [r.display_name for r in rows] diff --git a/services/web/server/src/simcore_service_webserver/api_keys/_service.py b/services/web/server/src/simcore_service_webserver/api_keys/_service.py index 4d7cdcb43dce..d5648e43060c 100644 --- a/services/web/server/src/simcore_service_webserver/api_keys/_service.py +++ b/services/web/server/src/simcore_service_webserver/api_keys/_service.py @@ -9,8 +9,8 @@ from servicelib.utils_secrets import generate_token_secret_key from . import _repository -from ._models import ApiKey from .errors import ApiKeyNotFoundError +from .models import ApiKey _PUNCTUATION_REGEX = re.compile( pattern="[" + re.escape(string.punctuation.replace("_", "")) + "]" @@ -32,8 +32,8 @@ async def create_api_key( *, user_id: UserID, product_name: ProductName, - display_name=str, - expiration=dt.timedelta, + display_name: str, + expiration: dt.timedelta | None, ) -> ApiKey: api_key, api_secret = _generate_api_key_and_secret(display_name) @@ -119,5 +119,5 @@ async def delete_api_key( async def prune_expired_api_keys(app: web.Application) -> list[str]: - names: list[str] = await _repository.prune_expired(app) + names: list[str] = await _repository.delete_expired_api_keys(app) return names diff --git a/services/web/server/src/simcore_service_webserver/api_keys/_models.py b/services/web/server/src/simcore_service_webserver/api_keys/models.py similarity index 100% rename from services/web/server/src/simcore_service_webserver/api_keys/_models.py rename to services/web/server/src/simcore_service_webserver/api_keys/models.py diff --git a/services/web/server/src/simcore_service_webserver/constants.py b/services/web/server/src/simcore_service_webserver/constants.py index b2997155ba4c..6c0dae060da9 100644 --- a/services/web/server/src/simcore_service_webserver/constants.py +++ b/services/web/server/src/simcore_service_webserver/constants.py @@ -38,15 +38,19 @@ # main index route name = front-end INDEX_RESOURCE_NAME: Final[str] = "get_cached_frontend_index" -MSG_UNDER_DEVELOPMENT: Final[ - str -] = "Under development. Use WEBSERVER_DEV_FEATURES_ENABLED=1 to enable current implementation" +MSG_UNDER_DEVELOPMENT: Final[str] = ( + "Under development. Use WEBSERVER_DEV_FEATURES_ENABLED=1 to enable current implementation" +) # Request storage keys RQ_PRODUCT_KEY: Final[str] = f"{__name__}.RQ_PRODUCT_KEY" +MSG_TRY_AGAIN_OR_SUPPORT: Final[str] = ( + "Please try again shortly. If the issue persists, contact support." +) + __all__: tuple[str, ...] = ( "APP_AIOPG_ENGINE_KEY", "APP_CONFIG_KEY", diff --git a/services/web/server/src/simcore_service_webserver/db/_aiopg.py b/services/web/server/src/simcore_service_webserver/db/_aiopg.py index 4a45a0a00fbb..9d9feea1f807 100644 --- a/services/web/server/src/simcore_service_webserver/db/_aiopg.py +++ b/services/web/server/src/simcore_service_webserver/db/_aiopg.py @@ -15,7 +15,7 @@ from servicelib.aiohttp.application_keys import APP_AIOPG_ENGINE_KEY from servicelib.logging_utils import log_context from servicelib.retry_policies import PostgresRetryPolicyUponInitialization -from simcore_postgres_database.errors import DBAPIError +from simcore_postgres_database.aiopg_errors import DBAPIError from simcore_postgres_database.utils_aiopg import ( DBMigrationError, close_engine, diff --git a/services/web/server/src/simcore_service_webserver/db/base_repository.py b/services/web/server/src/simcore_service_webserver/db/base_repository.py index 7c32e6182778..fc735e97254a 100644 --- a/services/web/server/src/simcore_service_webserver/db/base_repository.py +++ b/services/web/server/src/simcore_service_webserver/db/base_repository.py @@ -1,33 +1,26 @@ +from dataclasses import dataclass +from typing import Self + from aiohttp import web -from aiopg.sa.engine import Engine from models_library.users import UserID +from sqlalchemy.ext.asyncio import AsyncEngine from ..constants import RQT_USERID_KEY -from . import _aiopg +from . import _asyncpg +@dataclass(frozen=True) class BaseRepository: - def __init__(self, engine: Engine, user_id: UserID | None = None): - self._engine = engine - self._user_id = user_id - - assert isinstance(self._engine, Engine) # nosec + engine: AsyncEngine + user_id: UserID | None = None @classmethod - def create_from_request(cls, request: web.Request): + def create_from_request(cls, request: web.Request) -> Self: return cls( - engine=_aiopg.get_database_engine(request.app), + engine=_asyncpg.get_async_engine(request.app), user_id=request.get(RQT_USERID_KEY), ) @classmethod - def create_from_app(cls, app: web.Application): - return cls(engine=_aiopg.get_database_engine(app), user_id=None) - - @property - def engine(self) -> Engine: - return self._engine - - @property - def user_id(self) -> int | None: - return self._user_id + def create_from_app(cls, app: web.Application) -> Self: + return cls(engine=_asyncpg.get_async_engine(app)) diff --git a/services/web/server/src/simcore_service_webserver/folders/_folders_repository.py b/services/web/server/src/simcore_service_webserver/folders/_folders_repository.py index 072e944688ea..f57a1c6df848 100644 --- a/services/web/server/src/simcore_service_webserver/folders/_folders_repository.py +++ b/services/web/server/src/simcore_service_webserver/folders/_folders_repository.py @@ -275,7 +275,7 @@ async def list_trashed_folders( NOTE: this is app-wide i.e. no product, user or workspace filtered TODO: check with MD about workspaces """ - base_query = sql.select(_FOLDER_DB_MODEL_COLS).where( + base_query = sql.select(*_FOLDER_DB_MODEL_COLS).where( folders_v2.c.trashed.is_not(None) ) @@ -306,7 +306,9 @@ async def list_trashed_folders( def _create_base_select_query(folder_id: FolderID, product_name: ProductName) -> Select: - return sql.select(*_FOLDER_DB_MODEL_COLS,).where( + return sql.select( + *_FOLDER_DB_MODEL_COLS, + ).where( (folders_v2.c.product_name == product_name) & (folders_v2.c.folder_id == folder_id) ) diff --git a/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py b/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py index f89278ead785..1d44f39e793d 100644 --- a/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py +++ b/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py @@ -8,7 +8,7 @@ from models_library.users import UserID, UserNameID from redis.asyncio import Redis from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE -from simcore_postgres_database.errors import DatabaseError +from simcore_postgres_database.aiopg_errors import DatabaseError from simcore_postgres_database.models.users import UserRole from ..projects.db import ProjectDBAPI diff --git a/services/web/server/src/simcore_service_webserver/garbage_collector/_core_utils.py b/services/web/server/src/simcore_service_webserver/garbage_collector/_core_utils.py index 6a85dc835392..21ee7e294015 100644 --- a/services/web/server/src/simcore_service_webserver/garbage_collector/_core_utils.py +++ b/services/web/server/src/simcore_service_webserver/garbage_collector/_core_utils.py @@ -5,7 +5,7 @@ from models_library.groups import Group, GroupID, GroupType from models_library.projects import ProjectID from models_library.users import UserID -from simcore_postgres_database.errors import DatabaseError +from simcore_postgres_database.aiopg_errors import DatabaseError from ..groups.api import get_group_from_gid from ..projects.api import ( diff --git a/services/web/server/src/simcore_service_webserver/groups/_groups_repository.py b/services/web/server/src/simcore_service_webserver/groups/_groups_repository.py index 0d8b24b83fe8..05c292caf8d1 100644 --- a/services/web/server/src/simcore_service_webserver/groups/_groups_repository.py +++ b/services/web/server/src/simcore_service_webserver/groups/_groups_repository.py @@ -18,9 +18,9 @@ StandardGroupUpdate, ) from models_library.users import UserID -from simcore_postgres_database.errors import UniqueViolation +from simcore_postgres_database.aiopg_errors import UniqueViolation from simcore_postgres_database.models.users import users -from simcore_postgres_database.utils_products import execute_get_or_create_product_group +from simcore_postgres_database.utils_products import get_or_create_product_group from simcore_postgres_database.utils_repos import ( pass_or_acquire_connection, transaction_context, @@ -173,7 +173,6 @@ async def get_all_user_groups_with_read_access( *, user_id: UserID, ) -> GroupsByTypeTuple: - """ Returns the user primary group, standard groups and the all group """ @@ -758,7 +757,7 @@ async def auto_add_user_to_product_group( product_name: str, ) -> GroupID: async with transaction_context(get_asyncpg_engine(app), connection) as conn: - product_group_id: GroupID = await execute_get_or_create_product_group( + product_group_id: GroupID = await get_or_create_product_group( conn, product_name ) diff --git a/services/web/server/src/simcore_service_webserver/invitations/_rest.py b/services/web/server/src/simcore_service_webserver/invitations/_rest.py index ebbd53495037..ec0b8cbb1c0b 100644 --- a/services/web/server/src/simcore_service_webserver/invitations/_rest.py +++ b/services/web/server/src/simcore_service_webserver/invitations/_rest.py @@ -2,7 +2,7 @@ from aiohttp import web from models_library.api_schemas_invitations.invitations import ApiInvitationInputs -from models_library.api_schemas_webserver.product import ( +from models_library.api_schemas_webserver.products import ( InvitationGenerate, InvitationGenerated, ) diff --git a/services/web/server/src/simcore_service_webserver/login/_2fa_api.py b/services/web/server/src/simcore_service_webserver/login/_2fa_api.py index 8ab315902fb5..cda2bc1721d2 100644 --- a/services/web/server/src/simcore_service_webserver/login/_2fa_api.py +++ b/services/web/server/src/simcore_service_webserver/login/_2fa_api.py @@ -1,4 +1,4 @@ -""" two-factor-authentication utils +"""two-factor-authentication utils Currently includes two parts: @@ -10,6 +10,7 @@ import asyncio import logging +import twilio.rest # type: ignore[import-untyped] from aiohttp import web from models_library.users import UserID from pydantic import BaseModel, Field @@ -17,7 +18,6 @@ from servicelib.utils_secrets import generate_passcode from settings_library.twilio import TwilioSettings from twilio.base.exceptions import TwilioException # type: ignore[import-untyped] -from twilio.rest import Client # type: ignore[import-untyped] from ..login.errors import SendingVerificationEmailError, SendingVerificationSmsError from ..products.models import Product @@ -118,7 +118,8 @@ def _sender(): # # SEE https://www.twilio.com/docs/sms/quickstart/python # - client = Client( + # NOTE: this is mocked + client = twilio.rest.Client( twilio_auth.TWILIO_ACCOUNT_SID, twilio_auth.TWILIO_AUTH_TOKEN ) message = client.messages.create(**create_kwargs) diff --git a/services/web/server/src/simcore_service_webserver/login/handlers_confirmation.py b/services/web/server/src/simcore_service_webserver/login/handlers_confirmation.py index 886ee6c355aa..2568c3628fca 100644 --- a/services/web/server/src/simcore_service_webserver/login/handlers_confirmation.py +++ b/services/web/server/src/simcore_service_webserver/login/handlers_confirmation.py @@ -23,7 +23,7 @@ ) from servicelib.logging_errors import create_troubleshotting_log_kwargs from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON -from simcore_postgres_database.errors import UniqueViolation +from simcore_postgres_database.aiopg_errors import UniqueViolation from yarl import URL from ..products import products_web diff --git a/services/web/server/src/simcore_service_webserver/payments/_methods_db.py b/services/web/server/src/simcore_service_webserver/payments/_methods_db.py index 3b2bcf8ede81..135eaf41a9e2 100644 --- a/services/web/server/src/simcore_service_webserver/payments/_methods_db.py +++ b/services/web/server/src/simcore_service_webserver/payments/_methods_db.py @@ -1,7 +1,7 @@ import datetime import logging -import simcore_postgres_database.errors as db_errors +import simcore_postgres_database.aiopg_errors as db_errors import sqlalchemy as sa from aiohttp import web from aiopg.sa.result import ResultProxy diff --git a/services/web/server/src/simcore_service_webserver/payments/_rpc_invoice.py b/services/web/server/src/simcore_service_webserver/payments/_rpc_invoice.py index 3f1c72556383..d799d04fe6f7 100644 --- a/services/web/server/src/simcore_service_webserver/payments/_rpc_invoice.py +++ b/services/web/server/src/simcore_service_webserver/payments/_rpc_invoice.py @@ -4,11 +4,12 @@ from models_library.api_schemas_webserver import WEBSERVER_RPC_NAMESPACE from models_library.emails import LowerCaseEmailStr from models_library.payments import InvoiceDataGet, UserInvoiceAddress -from models_library.products import CreditResultGet, ProductName, ProductStripeInfoGet +from models_library.products import ProductName from models_library.users import UserID from servicelib.rabbitmq import RPCRouter from ..products import products_service +from ..products.models import CreditResult from ..rabbitmq import get_rabbitmq_rpc_server from ..users.api import get_user_display_and_id_names, get_user_invoice_address @@ -23,11 +24,11 @@ async def get_invoice_data( dollar_amount: Decimal, product_name: ProductName, ) -> InvoiceDataGet: - credit_result_get: CreditResultGet = await products_service.get_credit_amount( + credit_result: CreditResult = await products_service.get_credit_amount( app, dollar_amount=dollar_amount, product_name=product_name ) - product_stripe_info_get: ProductStripeInfoGet = ( - await products_service.get_product_stripe_info(app, product_name=product_name) + product_stripe_info = await products_service.get_product_stripe_info( + app, product_name=product_name ) user_invoice_address: UserInvoiceAddress = await get_user_invoice_address( app, user_id=user_id @@ -35,9 +36,9 @@ async def get_invoice_data( user_info = await get_user_display_and_id_names(app, user_id=user_id) return InvoiceDataGet( - credit_amount=credit_result_get.credit_amount, - stripe_price_id=product_stripe_info_get.stripe_price_id, - stripe_tax_rate_id=product_stripe_info_get.stripe_tax_rate_id, + credit_amount=credit_result.credit_amount, + stripe_price_id=product_stripe_info.stripe_price_id, + stripe_tax_rate_id=product_stripe_info.stripe_tax_rate_id, user_invoice_address=user_invoice_address, user_display_name=user_info.full_name, user_email=LowerCaseEmailStr(user_info.email), diff --git a/services/web/server/src/simcore_service_webserver/products/_controller/__init__.py b/services/web/server/src/simcore_service_webserver/products/_controller/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/services/web/server/src/simcore_service_webserver/products/_rest.py b/services/web/server/src/simcore_service_webserver/products/_controller/rest.py similarity index 76% rename from services/web/server/src/simcore_service_webserver/products/_rest.py rename to services/web/server/src/simcore_service_webserver/products/_controller/rest.py index 621eca41798e..77c72afe3b0c 100644 --- a/services/web/server/src/simcore_service_webserver/products/_rest.py +++ b/services/web/server/src/simcore_service_webserver/products/_controller/rest.py @@ -1,21 +1,22 @@ import logging from aiohttp import web -from models_library.api_schemas_webserver.product import ( +from models_library.api_schemas_webserver.products import ( CreditPriceGet, ProductGet, ProductUIGet, ) from servicelib.aiohttp.requests_validation import parse_request_path_parameters_as -from .._meta import API_VTAG as VTAG -from ..login.decorators import login_required -from ..security.decorators import permission_required -from ..utils_aiohttp import envelope_json_response -from . import _service, products_web -from ._repository import ProductRepository -from ._rest_schemas import ProductsRequestContext, ProductsRequestParams -from .models import Product +from ..._meta import API_VTAG as VTAG +from ...login.decorators import login_required +from ...security.decorators import permission_required +from ...utils_aiohttp import envelope_json_response +from .. import _service, products_web +from .._repository import ProductRepository +from ..models import Product +from .rest_exceptions import handle_rest_requests_exceptions +from .rest_schemas import ProductsRequestContext, ProductsRequestParams routes = web.RouteTableDef() @@ -26,6 +27,7 @@ @routes.get(f"/{VTAG}/credits-price", name="get_current_product_price") @login_required @permission_required("product.price.read") +@handle_rest_requests_exceptions async def _get_current_product_price(request: web.Request): req_ctx = ProductsRequestContext.model_validate(request) price_info = await products_web.get_current_product_credit_price_info(request) @@ -45,6 +47,7 @@ async def _get_current_product_price(request: web.Request): @routes.get(f"/{VTAG}/products/{{product_name}}", name="get_product") @login_required @permission_required("product.details.*") +@handle_rest_requests_exceptions async def _get_product(request: web.Request): req_ctx = ProductsRequestContext.model_validate(request) path_params = parse_request_path_parameters_as(ProductsRequestParams, request) @@ -54,10 +57,7 @@ async def _get_product(request: web.Request): else: product_name = path_params.product_name - try: - product: Product = _service.get_product(request.app, product_name=product_name) - except KeyError as err: - raise web.HTTPNotFound(reason=f"{product_name=} not found") from err + product: Product = _service.get_product(request.app, product_name=product_name) assert "extra" in ProductGet.model_config # nosec assert ProductGet.model_config["extra"] == "ignore" # nosec @@ -68,6 +68,7 @@ async def _get_product(request: web.Request): @routes.get(f"/{VTAG}/products/current/ui", name="get_current_product_ui") @login_required @permission_required("product.ui.read") +@handle_rest_requests_exceptions async def _get_current_product_ui(request: web.Request): req_ctx = ProductsRequestContext.model_validate(request) product_name = req_ctx.product_name diff --git a/services/web/server/src/simcore_service_webserver/products/_controller/rest_exceptions.py b/services/web/server/src/simcore_service_webserver/products/_controller/rest_exceptions.py new file mode 100644 index 000000000000..a9e8cb13f00c --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/products/_controller/rest_exceptions.py @@ -0,0 +1,26 @@ +from servicelib.aiohttp import status + +from ...constants import MSG_TRY_AGAIN_OR_SUPPORT +from ...exception_handling import ( + ExceptionToHttpErrorMap, + HttpErrorInfo, + exception_handling_decorator, + to_exceptions_handlers_map, +) +from ..errors import MissingStripeConfigError, ProductNotFoundError + +_TO_HTTP_ERROR_MAP: ExceptionToHttpErrorMap = { + ProductNotFoundError: HttpErrorInfo( + status.HTTP_404_NOT_FOUND, + "{product_name} was not found", + ), + MissingStripeConfigError: HttpErrorInfo( + status.HTTP_503_SERVICE_UNAVAILABLE, + "{product_name} service is currently unavailable." + MSG_TRY_AGAIN_OR_SUPPORT, + ), +} + + +handle_rest_requests_exceptions = exception_handling_decorator( + to_exceptions_handlers_map(_TO_HTTP_ERROR_MAP) +) diff --git a/services/web/server/src/simcore_service_webserver/products/_rest_schemas.py b/services/web/server/src/simcore_service_webserver/products/_controller/rest_schemas.py similarity index 94% rename from services/web/server/src/simcore_service_webserver/products/_rest_schemas.py rename to services/web/server/src/simcore_service_webserver/products/_controller/rest_schemas.py index 5aa0938bba00..6a4ac2100b17 100644 --- a/services/web/server/src/simcore_service_webserver/products/_rest_schemas.py +++ b/services/web/server/src/simcore_service_webserver/products/_controller/rest_schemas.py @@ -9,7 +9,7 @@ from pydantic import Field from servicelib.request_keys import RQT_USERID_KEY -from ..constants import RQ_PRODUCT_KEY +from ...constants import RQ_PRODUCT_KEY routes = web.RouteTableDef() diff --git a/services/web/server/src/simcore_service_webserver/products/_controller/rpc.py b/services/web/server/src/simcore_service_webserver/products/_controller/rpc.py new file mode 100644 index 000000000000..852cf2e4f8c0 --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/products/_controller/rpc.py @@ -0,0 +1,38 @@ +from decimal import Decimal + +from aiohttp import web +from models_library.api_schemas_webserver import WEBSERVER_RPC_NAMESPACE +from models_library.api_schemas_webserver.products import CreditResultRpcGet +from models_library.products import ProductName +from servicelib.rabbitmq import RPCRouter + +from ...constants import APP_SETTINGS_KEY +from ...rabbitmq import get_rabbitmq_rpc_server, setup_rabbitmq +from .. import _service +from .._models import CreditResult + +router = RPCRouter() + + +@router.expose() +async def get_credit_amount( + app: web.Application, + *, + dollar_amount: Decimal, + product_name: ProductName, +) -> CreditResultRpcGet: + credit_result: CreditResult = await _service.get_credit_amount( + app, dollar_amount=dollar_amount, product_name=product_name + ) + return CreditResultRpcGet.model_validate(credit_result, from_attributes=True) + + +async def _register_rpc_routes_on_startup(app: web.Application): + rpc_server = get_rabbitmq_rpc_server(app) + await rpc_server.register_router(router, WEBSERVER_RPC_NAMESPACE, app) + + +def setup_rpc(app: web.Application): + setup_rabbitmq(app) + if app[APP_SETTINGS_KEY].WEBSERVER_RABBITMQ: + app.on_startup.append(_register_rpc_routes_on_startup) diff --git a/services/web/server/src/simcore_service_webserver/products/_models.py b/services/web/server/src/simcore_service_webserver/products/_models.py index 22caaa8408f0..dbab8b60a9bf 100644 --- a/services/web/server/src/simcore_service_webserver/products/_models.py +++ b/services/web/server/src/simcore_service_webserver/products/_models.py @@ -1,6 +1,8 @@ import logging import re import string +from dataclasses import dataclass +from decimal import Decimal from typing import Annotated, Any from models_library.basic_regex import ( @@ -9,7 +11,7 @@ ) from models_library.basic_types import NonNegativeDecimal from models_library.emails import LowerCaseEmailStr -from models_library.products import ProductName +from models_library.products import ProductName, StripePriceID, StripeTaxRateID from models_library.utils.change_case import snake_to_camel from pydantic import ( BaseModel, @@ -37,6 +39,25 @@ _logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class CreditResult: + product_name: ProductName + credit_amount: Decimal + + +@dataclass(frozen=True) +class ProductStripeInfo: + stripe_price_id: StripePriceID + stripe_tax_rate_id: StripeTaxRateID + + +@dataclass(frozen=True) +class PaymentFields: + enabled: bool + credits_per_usd: Decimal | None + min_payment_amount_usd: Decimal | None + + class Product(BaseModel): """Model used to parse a row of pg product's table diff --git a/services/web/server/src/simcore_service_webserver/products/_repository.py b/services/web/server/src/simcore_service_webserver/products/_repository.py index 7a45e7ebb392..16a677b0c829 100644 --- a/services/web/server/src/simcore_service_webserver/products/_repository.py +++ b/services/web/server/src/simcore_service_webserver/products/_repository.py @@ -1,23 +1,33 @@ import logging -from collections.abc import AsyncIterator from decimal import Decimal -from typing import Any, NamedTuple +from typing import Any import sqlalchemy as sa -from aiopg.sa.connection import SAConnection -from aiopg.sa.result import ResultProxy, RowProxy -from models_library.products import ProductName, ProductStripeInfoGet +from models_library.groups import GroupID +from models_library.products import ProductName from simcore_postgres_database.constants import QUANTIZE_EXP_ARG from simcore_postgres_database.models.jinja2_templates import jinja2_templates +from simcore_postgres_database.models.products import products +from simcore_postgres_database.utils_products import ( + get_default_product_name, + get_or_create_product_group, +) from simcore_postgres_database.utils_products_prices import ( ProductPriceInfo, get_product_latest_price_info_or_none, - get_product_latest_stripe_info, + get_product_latest_stripe_info_or_none, +) +from simcore_postgres_database.utils_repos import ( + pass_or_acquire_connection, + transaction_context, ) +from simcore_service_webserver.constants import FRONTEND_APPS_AVAILABLE +from sqlalchemy.engine import Row +from sqlalchemy.ext.asyncio import AsyncConnection +from ..constants import FRONTEND_APPS_AVAILABLE from ..db.base_repository import BaseRepository -from ..db.models import products -from .models import Product +from ._models import PaymentFields, Product, ProductStripeInfo _logger = logging.getLogger(__name__) @@ -45,30 +55,36 @@ products.c.group_id, ] +assert {column.name for column in _PRODUCTS_COLUMNS}.issubset( # nosec + set(Product.model_fields) +) -class PaymentFieldsTuple(NamedTuple): - enabled: bool - credits_per_usd: Decimal | None - min_payment_amount_usd: Decimal | None + +def _to_domain(products_row: Row, payments: PaymentFields) -> Product: + return Product( + **products_row._asdict(), + is_payment_enabled=payments.enabled, + credits_per_usd=payments.credits_per_usd, + ) -async def get_product_payment_fields( - conn: SAConnection, product_name: ProductName -) -> PaymentFieldsTuple: +async def _get_product_payment_fields( + conn: AsyncConnection, product_name: ProductName +) -> PaymentFields: price_info = await get_product_latest_price_info_or_none( conn, product_name=product_name ) if price_info is None or price_info.usd_per_credit == 0: - return PaymentFieldsTuple( + return PaymentFields( enabled=False, credits_per_usd=None, min_payment_amount_usd=None, ) - assert price_info.usd_per_credit > 0 - assert price_info.min_payment_amount_usd > 0 + assert price_info.usd_per_credit > 0 # nosec + assert price_info.min_payment_amount_usd > 0 # nosec - return PaymentFieldsTuple( + return PaymentFields( enabled=True, credits_per_usd=Decimal(1 / price_info.usd_per_credit).quantize( QUANTIZE_EXP_ARG @@ -77,93 +93,141 @@ async def get_product_payment_fields( ) -async def iter_products(conn: SAConnection) -> AsyncIterator[ResultProxy]: - """Iterates on products sorted by priority i.e. the first is considered the default""" - async for row in conn.execute( - sa.select(*_PRODUCTS_COLUMNS).order_by(products.c.priority) - ): - assert row # nosec - yield row +class ProductRepository(BaseRepository): + async def list_products( + self, + connection: AsyncConnection | None = None, + ) -> list[Product]: + """ + Raises: + ValidationError:if products are not setup correctly in the database + """ + app_products: list[Product] = [] -class ProductRepository(BaseRepository): - async def list_products_names(self) -> list[ProductName]: - async with self.engine.acquire() as conn: - query = sa.select(products.c.name).order_by(products.c.priority) - result = await conn.execute(query) - rows = await result.fetchall() - return [ProductName(row.name) for row in rows] + query = sa.select(*_PRODUCTS_COLUMNS).order_by(products.c.priority) - async def get_product(self, product_name: str) -> Product | None: - async with self.engine.acquire() as conn: - result: ResultProxy = await conn.execute( - sa.select(*_PRODUCTS_COLUMNS).where(products.c.name == product_name) - ) - row: RowProxy | None = await result.first() - if row: - # NOTE: MD Observation: Currently we are not defensive, we assume automatically - # that the product is not billable when there is no product in the products_prices table - # or it's price is 0. We should change it and always assume that the product is billable, unless - # explicitely stated that it is free - payments = await get_product_payment_fields(conn, product_name=row.name) - return Product( - **dict(row.items()), - is_payment_enabled=payments.enabled, - credits_per_usd=payments.credits_per_usd, + async with pass_or_acquire_connection(self.engine, connection) as conn: + rows = await conn.stream(query) + async for row in rows: + name = row.name + payments = await _get_product_payment_fields(conn, product_name=name) + app_products.append(_to_domain(row, payments)) + + assert name in FRONTEND_APPS_AVAILABLE # nosec + + return app_products + + async def list_products_names( + self, + connection: AsyncConnection | None = None, + ) -> list[ProductName]: + query = sa.select(products.c.name).order_by(products.c.priority) + + async with pass_or_acquire_connection(self.engine, connection) as conn: + rows = await conn.stream(query) + return [ProductName(row.name) async for row in rows] + + async def get_product( + self, product_name: str, connection: AsyncConnection | None = None + ) -> Product | None: + query = sa.select(*_PRODUCTS_COLUMNS).where(products.c.name == product_name) + + async with pass_or_acquire_connection(self.engine, connection) as conn: + result = await conn.execute(query) + if row := result.one_or_none(): + payments = await _get_product_payment_fields( + conn, product_name=row.name ) + return _to_domain(row, payments) return None + async def get_default_product_name( + self, connection: AsyncConnection | None = None + ) -> ProductName: + async with pass_or_acquire_connection(self.engine, connection) as conn: + return await get_default_product_name(conn) + async def get_product_latest_price_info_or_none( - self, product_name: str + self, product_name: str, connection: AsyncConnection | None = None ) -> ProductPriceInfo | None: - """newest price of a product or None if not billable""" - async with self.engine.acquire() as conn: + async with pass_or_acquire_connection(self.engine, connection) as conn: return await get_product_latest_price_info_or_none( conn, product_name=product_name ) - async def get_product_stripe_info(self, product_name: str) -> ProductStripeInfoGet: - async with self.engine.acquire() as conn: - row = await get_product_latest_stripe_info(conn, product_name=product_name) - return ProductStripeInfoGet( - stripe_price_id=row[0], stripe_tax_rate_id=row[1] + async def get_product_stripe_info_or_none( + self, product_name: str, connection: AsyncConnection | None = None + ) -> ProductStripeInfo | None: + async with pass_or_acquire_connection(self.engine, connection) as conn: + latest_stripe_info = await get_product_latest_stripe_info_or_none( + conn, product_name=product_name + ) + if latest_stripe_info is None: + return None + + stripe_price_id, stripe_tax_rate_id = latest_stripe_info + return ProductStripeInfo( + stripe_price_id=stripe_price_id, stripe_tax_rate_id=stripe_tax_rate_id ) async def get_template_content( - self, - template_name: str, + self, template_name: str, connection: AsyncConnection | None = None ) -> str | None: - async with self.engine.acquire() as conn: - template_content: str | None = await conn.scalar( - sa.select(jinja2_templates.c.content).where( - jinja2_templates.c.name == template_name - ) - ) + query = sa.select(jinja2_templates.c.content).where( + jinja2_templates.c.name == template_name + ) + + async with pass_or_acquire_connection(self.engine, connection) as conn: + template_content: str | None = await conn.scalar(query) return template_content async def get_product_template_content( self, product_name: str, product_template: sa.Column = products.c.registration_email_template, + connection: AsyncConnection | None = None, ) -> str | None: - async with self.engine.acquire() as conn: - oj = sa.join( - products, - jinja2_templates, - product_template == jinja2_templates.c.name, - isouter=True, - ) - content = await conn.scalar( - sa.select(jinja2_templates.c.content) - .select_from(oj) - .where(products.c.name == product_name) + query = ( + sa.select(jinja2_templates.c.content) + .select_from( + sa.join( + products, + jinja2_templates, + product_template == jinja2_templates.c.name, + isouter=True, + ) ) - return f"{content}" if content else None + .where(products.c.name == product_name) + ) - async def get_product_ui(self, product_name: ProductName) -> dict[str, Any] | None: - async with self.engine.acquire() as conn: - result = await conn.execute( - sa.select(products.c.ui).where(products.c.name == product_name) - ) - row: RowProxy | None = await result.first() + async with pass_or_acquire_connection(self.engine, connection) as conn: + template_content: str | None = await conn.scalar(query) + return template_content + + async def get_product_ui( + self, product_name: ProductName, connection: AsyncConnection | None = None + ) -> dict[str, Any] | None: + query = sa.select(products.c.ui).where(products.c.name == product_name) + + async with pass_or_acquire_connection(self.engine, connection) as conn: + result = await conn.execute(query) + row = result.one_or_none() return dict(**row.ui) if row else None + + async def auto_create_products_groups( + self, + connection: AsyncConnection | None = None, + ) -> dict[ProductName, GroupID]: + product_groups_map: dict[ProductName, GroupID] = {} + + product_names = await self.list_products_names(connection) + for product_name in product_names: + # NOTE: transaction is per product. fail-fast! + async with transaction_context(self.engine, connection) as conn: + product_group_id: GroupID = await get_or_create_product_group( + conn, product_name + ) + product_groups_map[product_name] = product_group_id + + return product_groups_map diff --git a/services/web/server/src/simcore_service_webserver/products/_rpc.py b/services/web/server/src/simcore_service_webserver/products/_rpc.py deleted file mode 100644 index b3e0329f1a8e..000000000000 --- a/services/web/server/src/simcore_service_webserver/products/_rpc.py +++ /dev/null @@ -1,29 +0,0 @@ -from decimal import Decimal - -from aiohttp import web -from models_library.api_schemas_webserver import WEBSERVER_RPC_NAMESPACE -from models_library.products import CreditResultGet, ProductName -from servicelib.rabbitmq import RPCRouter - -from ..rabbitmq import get_rabbitmq_rpc_server -from . import _service - -router = RPCRouter() - - -@router.expose() -async def get_credit_amount( - app: web.Application, - *, - dollar_amount: Decimal, - product_name: ProductName, -) -> CreditResultGet: - credit_result_get: CreditResultGet = await _service.get_credit_amount( - app, dollar_amount=dollar_amount, product_name=product_name - ) - return credit_result_get - - -async def register_rpc_routes_on_startup(app: web.Application): - rpc_server = get_rabbitmq_rpc_server(app) - await rpc_server.register_router(router, WEBSERVER_RPC_NAMESPACE, app) diff --git a/services/web/server/src/simcore_service_webserver/products/_service.py b/services/web/server/src/simcore_service_webserver/products/_service.py index b5c17669e288..032f20d8083e 100644 --- a/services/web/server/src/simcore_service_webserver/products/_service.py +++ b/services/web/server/src/simcore_service_webserver/products/_service.py @@ -1,14 +1,19 @@ from decimal import Decimal -from typing import Any, cast +from typing import Any from aiohttp import web -from models_library.products import CreditResultGet, ProductName, ProductStripeInfoGet +from models_library.groups import GroupID +from models_library.products import ProductName +from pydantic import ValidationError +from servicelib.exceptions import InvalidConfig from simcore_postgres_database.utils_products_prices import ProductPriceInfo from ..constants import APP_PRODUCTS_KEY +from ._models import CreditResult, ProductStripeInfo from ._repository import ProductRepository from .errors import ( BelowMinimumPaymentError, + MissingStripeConfigError, ProductNotFoundError, ProductPriceNotDefinedError, ProductTemplateNotFoundError, @@ -16,9 +21,27 @@ from .models import Product +async def load_products(app: web.Application) -> list[Product]: + repo = ProductRepository.create_from_app(app) + try: + # NOTE: list_products implemented as fails-fast! + return await repo.list_products() + except ValidationError as err: + msg = f"Invalid product configuration in db:\n {err}" + raise InvalidConfig(msg) from err + + +async def get_default_product_name(app: web.Application) -> ProductName: + repo = ProductRepository.create_from_app(app) + return await repo.get_default_product_name() + + def get_product(app: web.Application, product_name: ProductName) -> Product: - product: Product = app[APP_PRODUCTS_KEY][product_name] - return product + try: + product: Product = app[APP_PRODUCTS_KEY][product_name] + return product + except KeyError as exc: + raise ProductNotFoundError(product_name=product_name) from exc def list_products(app: web.Application) -> list[Product]: @@ -36,10 +59,7 @@ async def get_credit_price_info( app: web.Application, product_name: ProductName ) -> ProductPriceInfo | None: repo = ProductRepository.create_from_app(app) - return cast( # mypy: not sure why - ProductPriceInfo | None, - await repo.get_product_latest_price_info_or_none(product_name), - ) + return await repo.get_product_latest_price_info_or_none(product_name) async def get_product_ui( @@ -57,7 +77,7 @@ async def get_credit_amount( *, dollar_amount: Decimal, product_name: ProductName, -) -> CreditResultGet: +) -> CreditResult: """For provided dollars and product gets credit amount. NOTE: Contrary to other product api functions (e.g. get_current_product) this function @@ -85,22 +105,27 @@ async def get_credit_amount( ) credit_amount = dollar_amount / price_info.usd_per_credit - return CreditResultGet(product_name=product_name, credit_amount=credit_amount) + return CreditResult(product_name=product_name, credit_amount=credit_amount) async def get_product_stripe_info( app: web.Application, *, product_name: ProductName -) -> ProductStripeInfoGet: +) -> ProductStripeInfo: repo = ProductRepository.create_from_app(app) - product_stripe_info = await repo.get_product_stripe_info(product_name) + + product_stripe_info = await repo.get_product_stripe_info_or_none(product_name) if ( - not product_stripe_info + product_stripe_info is None or "missing!!" in product_stripe_info.stripe_price_id or "missing!!" in product_stripe_info.stripe_tax_rate_id ): - msg = f"Missing product stripe for product {product_name}" - raise ValueError(msg) - return cast(ProductStripeInfoGet, product_stripe_info) # mypy: not sure why + exc = MissingStripeConfigError( + product_name=product_name, + product_stripe_info=product_stripe_info, + ) + exc.add_note("Probably stripe side is not configured") + raise exc + return product_stripe_info async def get_template_content(app: web.Application, *, template_name: str): @@ -109,3 +134,10 @@ async def get_template_content(app: web.Application, *, template_name: str): if not content: raise ProductTemplateNotFoundError(template_name=template_name) return content + + +async def auto_create_products_groups( + app: web.Application, +) -> dict[ProductName, GroupID]: + repo = ProductRepository.create_from_app(app) + return await repo.auto_create_products_groups() diff --git a/services/web/server/src/simcore_service_webserver/products/_web_events.py b/services/web/server/src/simcore_service_webserver/products/_web_events.py index 5f14cafca884..7000cb21b1e9 100644 --- a/services/web/server/src/simcore_service_webserver/products/_web_events.py +++ b/services/web/server/src/simcore_service_webserver/products/_web_events.py @@ -1,43 +1,21 @@ import logging import tempfile -from collections import OrderedDict from pathlib import Path +from pprint import pformat from aiohttp import web -from aiopg.sa.engine import Engine -from aiopg.sa.result import RowProxy -from pydantic import ValidationError -from servicelib.exceptions import InvalidConfig -from simcore_postgres_database.utils_products import ( - get_default_product_name, - get_or_create_product_group, -) - -from ..constants import APP_PRODUCTS_KEY, FRONTEND_APP_DEFAULT, FRONTEND_APPS_AVAILABLE -from ..db.plugin import get_database_engine -from ._repository import get_product_payment_fields, iter_products -from .models import Product +from models_library.products import ProductName + +from ..constants import APP_PRODUCTS_KEY +from . import _service +from ._models import Product _logger = logging.getLogger(__name__) APP_PRODUCTS_TEMPLATES_DIR_KEY = f"{__name__}.template_dir" -async def setup_product_templates(app: web.Application): - """ - builds a directory and download product templates - """ - with tempfile.TemporaryDirectory( - suffix=APP_PRODUCTS_TEMPLATES_DIR_KEY - ) as templates_dir: - app[APP_PRODUCTS_TEMPLATES_DIR_KEY] = Path(templates_dir) - - yield - - # cleanup - - -async def auto_create_products_groups(app: web.Application) -> None: +async def _auto_create_products_groups(app: web.Application) -> None: """Ensures all products have associated group ids Avoids having undefined groups in products with new products.group_id column @@ -45,63 +23,57 @@ async def auto_create_products_groups(app: web.Application) -> None: NOTE: could not add this in 'setup_groups' (groups plugin) since it has to be executed BEFORE 'load_products_on_startup' """ - engine = get_database_engine(app) - - async with engine.acquire() as connection: - async for row in iter_products(connection): - product_name = row.name # type: ignore[attr-defined] # sqlalchemy - product_group_id = await get_or_create_product_group( - connection, product_name - ) - _logger.debug( - "Product with %s has an associated group with %s", - f"{product_name=}", - f"{product_group_id=}", - ) + product_groups_map = await _service.auto_create_products_groups(app) + _logger.debug("Products group IDs: %s", pformat(product_groups_map)) def _set_app_state( app: web.Application, - app_products: OrderedDict[str, Product], + app_products: dict[ProductName, Product], default_product_name: str, ): + # NOTE: products are checked on every request, therefore we + # cache them in the `app` upon startup app[APP_PRODUCTS_KEY] = app_products assert default_product_name in app_products # nosec app[f"{APP_PRODUCTS_KEY}_default"] = default_product_name -async def load_products_on_startup(app: web.Application): +async def _load_products_on_startup(app: web.Application): """ Loads info on products stored in the database into app's storage (i.e. memory) """ - app_products: OrderedDict[str, Product] = OrderedDict() - engine: Engine = get_database_engine(app) - async with engine.acquire() as connection: - async for row in iter_products(connection): - assert isinstance(row, RowProxy) # nosec - try: - name = row.name + app_products: dict[ProductName, Product] = { + product.name: product for product in await _service.load_products(app) + } + + default_product_name = await _service.get_default_product_name(app) + + _set_app_state(app, app_products, default_product_name) + assert APP_PRODUCTS_KEY in app # nosec - payments = await get_product_payment_fields( - connection, product_name=name - ) + _logger.debug("Product loaded: %s", list(app_products)) - app_products[name] = Product( - **dict(row.items()), - is_payment_enabled=payments.enabled, - credits_per_usd=payments.credits_per_usd, - ) - assert name in FRONTEND_APPS_AVAILABLE # nosec +async def _setup_product_templates(app: web.Application): + """ + builds a directory and download product templates + """ + with tempfile.TemporaryDirectory( + suffix=APP_PRODUCTS_TEMPLATES_DIR_KEY + ) as templates_dir: + app[APP_PRODUCTS_TEMPLATES_DIR_KEY] = Path(templates_dir) - except ValidationError as err: - msg = f"Invalid product configuration in db '{row}':\n {err}" - raise InvalidConfig(msg) from err + yield - assert FRONTEND_APP_DEFAULT in app_products # nosec + # cleanup - default_product_name = await get_default_product_name(connection) - _set_app_state(app, app_products, default_product_name) +def setup_web_events(app: web.Application): - _logger.debug("Product loaded: %s", [p.name for p in app_products.values()]) + app.on_startup.append( + # NOTE: must go BEFORE _load_products_on_startup + _auto_create_products_groups + ) + app.on_startup.append(_load_products_on_startup) + app.cleanup_ctx.append(_setup_product_templates) diff --git a/services/web/server/src/simcore_service_webserver/products/_web_helpers.py b/services/web/server/src/simcore_service_webserver/products/_web_helpers.py index a1990aeb2138..859793d9e0a8 100644 --- a/services/web/server/src/simcore_service_webserver/products/_web_helpers.py +++ b/services/web/server/src/simcore_service_webserver/products/_web_helpers.py @@ -1,9 +1,15 @@ +import contextlib from pathlib import Path import aiofiles from aiohttp import web from models_library.products import ProductName from simcore_postgres_database.utils_products_prices import ProductPriceInfo +from simcore_service_webserver.products.errors import ( + FileTemplateNotFoundError, + ProductNotFoundError, + UnknownProductError, +) from .._resources import webserver_resources from ..constants import RQ_PRODUCT_KEY @@ -14,7 +20,13 @@ def get_product_name(request: web.Request) -> str: """Returns product name in request but might be undefined""" - product_name: str = request[RQ_PRODUCT_KEY] + # NOTE: introduced by middleware + try: + product_name: str = request[RQ_PRODUCT_KEY] + except KeyError as exc: + error = UnknownProductError() + error.add_note("TIP: Check products middleware") + raise error from exc return product_name @@ -27,6 +39,13 @@ def get_current_product(request: web.Request) -> Product: return current_product +def _get_current_product_or_none(request: web.Request) -> Product | None: + with contextlib.suppress(ProductNotFoundError, UnknownProductError): + product: Product = get_current_product(request) + return product + return None + + async def get_current_product_credit_price_info( request: web.Request, ) -> ProductPriceInfo | None: @@ -48,19 +67,10 @@ def _themed(dirname: str, template: str) -> Path: return path -def _get_current_product_or_none(request: web.Request) -> Product | None: - try: - product: Product = get_current_product(request) - return product - except KeyError: - return None - - async def _get_common_template_path(filename: str) -> Path: common_template = _themed("templates/common", filename) if not common_template.exists(): - msg = f"{filename} is not part of the templates/common" - raise ValueError(msg) + raise FileTemplateNotFoundError(filename=filename) return common_template diff --git a/services/web/server/src/simcore_service_webserver/products/errors.py b/services/web/server/src/simcore_service_webserver/products/errors.py index a4a42542f58a..3b0da3564f51 100644 --- a/services/web/server/src/simcore_service_webserver/products/errors.py +++ b/services/web/server/src/simcore_service_webserver/products/errors.py @@ -1,8 +1,11 @@ from ..errors import WebServerBaseError -class ProductError(WebServerBaseError, ValueError): - ... +class ProductError(WebServerBaseError, ValueError): ... + + +class UnknownProductError(ProductError): + msg_template = "Cannot determine which is the product in the current context" class ProductNotFoundError(ProductError): @@ -19,3 +22,14 @@ class BelowMinimumPaymentError(ProductError): class ProductTemplateNotFoundError(ProductError): msg_template = "Missing template {template_name} for product" + + +class MissingStripeConfigError(ProductError): + msg_template = ( + "Missing product stripe for product {product_name}.\n" + "NOTE: This is currently setup manually by the operator in pg database via adminer and also in the stripe platform." + ) + + +class FileTemplateNotFoundError(ProductError): + msg_template = "{filename} is not part of the templates/common" diff --git a/services/web/server/src/simcore_service_webserver/products/models.py b/services/web/server/src/simcore_service_webserver/products/models.py index 29204d49c477..4625012a4843 100644 --- a/services/web/server/src/simcore_service_webserver/products/models.py +++ b/services/web/server/src/simcore_service_webserver/products/models.py @@ -1,9 +1,11 @@ from models_library.products import ProductName -from ._models import Product +from ._models import CreditResult, Product, ProductStripeInfo __all__: tuple[str, ...] = ( + "CreditResult", "Product", "ProductName", + "ProductStripeInfo", ) # nopycln: file diff --git a/services/web/server/src/simcore_service_webserver/products/plugin.py b/services/web/server/src/simcore_service_webserver/products/plugin.py index 062032e4221d..5aea6edcf7e2 100644 --- a/services/web/server/src/simcore_service_webserver/products/plugin.py +++ b/services/web/server/src/simcore_service_webserver/products/plugin.py @@ -29,26 +29,15 @@ def setup_products(app: web.Application): # specially if this plugin is not set up to be loaded # from ..constants import APP_SETTINGS_KEY - from ..rabbitmq import setup_rabbitmq - from . import _rest, _rpc, _web_events, _web_middlewares + from . import _web_events, _web_middlewares + from ._controller import rest, rpc assert app[APP_SETTINGS_KEY].WEBSERVER_PRODUCTS is True # nosec - # set middlewares app.middlewares.append(_web_middlewares.discover_product_middleware) - # setup rest - app.router.add_routes(_rest.routes) - - # setup rpc - setup_rabbitmq(app) - if app[APP_SETTINGS_KEY].WEBSERVER_RABBITMQ: - app.on_startup.append(_rpc.register_rpc_routes_on_startup) - - # setup events - app.on_startup.append( - # NOTE: must go BEFORE load_products_on_startup - _web_events.auto_create_products_groups - ) - app.on_startup.append(_web_events.load_products_on_startup) - app.cleanup_ctx.append(_web_events.setup_product_templates) + app.router.add_routes(rest.routes) + + rpc.setup_rpc(app) + + _web_events.setup_web_events(app) diff --git a/services/web/server/src/simcore_service_webserver/projects/_projects_db.py b/services/web/server/src/simcore_service_webserver/projects/_projects_db.py index 8787699fe832..2754d1124a9c 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_projects_db.py +++ b/services/web/server/src/simcore_service_webserver/projects/_projects_db.py @@ -62,7 +62,7 @@ async def list_trashed_projects( order_by: OrderBy = _OLDEST_TRASHED_FIRST, ) -> tuple[int, list[ProjectDBGet]]: - base_query = sql.select(PROJECT_DB_COLS).where(projects.c.trashed.is_not(None)) + base_query = sql.select(*PROJECT_DB_COLS).where(projects.c.trashed.is_not(None)) if is_set(trashed_explicitly): assert isinstance(trashed_explicitly, bool) # nosec diff --git a/services/web/server/src/simcore_service_webserver/projects/db.py b/services/web/server/src/simcore_service_webserver/projects/db.py index f27b2573dc2c..7d512b484d65 100644 --- a/services/web/server/src/simcore_service_webserver/projects/db.py +++ b/services/web/server/src/simcore_service_webserver/projects/db.py @@ -1,7 +1,7 @@ -""" Database API +"""Database API - - Adds a layer to the postgres API with a focus on the projects data - - Shall be used as entry point for all the queries to the database regarding projects +- Adds a layer to the postgres API with a focus on the projects data +- Shall be used as entry point for all the queries to the database regarding projects """ @@ -36,7 +36,7 @@ from pydantic.types import PositiveInt from servicelib.aiohttp.application_keys import APP_AIOPG_ENGINE_KEY from servicelib.logging_utils import get_log_record_extra, log_context -from simcore_postgres_database.errors import UniqueViolation +from simcore_postgres_database.aiopg_errors import UniqueViolation from simcore_postgres_database.models.groups import user_to_groups from simcore_postgres_database.models.project_to_groups import project_to_groups from simcore_postgres_database.models.projects_nodes import projects_nodes diff --git a/services/web/server/src/simcore_service_webserver/security/_authz_policy.py b/services/web/server/src/simcore_service_webserver/security/_authz_policy.py index 612c1e649756..3bd5408f4d3c 100644 --- a/services/web/server/src/simcore_service_webserver/security/_authz_policy.py +++ b/services/web/server/src/simcore_service_webserver/security/_authz_policy.py @@ -1,6 +1,4 @@ -""" AUTHoriZation (auth) policy - -""" +"""AUTHoriZation (auth) policy""" import contextlib import logging @@ -14,7 +12,7 @@ ) from models_library.products import ProductName from models_library.users import UserID -from simcore_postgres_database.errors import DatabaseError +from simcore_postgres_database.aiopg_errors import DatabaseError from ..db.plugin import get_database_engine from ._authz_access_model import ( diff --git a/services/web/server/src/simcore_service_webserver/wallets/_payments_handlers.py b/services/web/server/src/simcore_service_webserver/wallets/_payments_handlers.py index e0b4baddd91c..2751abc457e1 100644 --- a/services/web/server/src/simcore_service_webserver/wallets/_payments_handlers.py +++ b/services/web/server/src/simcore_service_webserver/wallets/_payments_handlers.py @@ -12,7 +12,6 @@ ReplaceWalletAutoRecharge, WalletPaymentInitiated, ) -from models_library.products import CreditResultGet from models_library.rest_pagination import Page, PageQueryParameters from models_library.rest_pagination_utils import paginate_data from servicelib.aiohttp import status @@ -24,6 +23,7 @@ ) from servicelib.logging_utils import get_log_record_extra, log_context from servicelib.utils import fire_and_forget_task +from simcore_service_webserver.products._models import CreditResult from .._meta import API_VTAG as VTAG from ..login.decorators import login_required @@ -79,7 +79,7 @@ async def _create_payment(request: web.Request): log_duration=True, extra=get_log_record_extra(user_id=req_ctx.user_id), ): - credit_result: CreditResultGet = await products_service.get_credit_amount( + credit_result: CreditResult = await products_service.get_credit_amount( request.app, dollar_amount=body_params.price_dollars, product_name=req_ctx.product_name, @@ -351,7 +351,7 @@ async def _pay_with_payment_method(request: web.Request): log_duration=True, extra=get_log_record_extra(user_id=req_ctx.user_id), ): - credit_result: CreditResultGet = await products_service.get_credit_amount( + credit_result: CreditResult = await products_service.get_credit_amount( request.app, dollar_amount=body_params.price_dollars, product_name=req_ctx.product_name, diff --git a/services/web/server/tests/unit/conftest.py b/services/web/server/tests/unit/conftest.py index b322655c20cc..4c6dd952f46a 100644 --- a/services/web/server/tests/unit/conftest.py +++ b/services/web/server/tests/unit/conftest.py @@ -10,10 +10,10 @@ from collections.abc import Callable, Iterable from pathlib import Path from typing import Any -from unittest.mock import MagicMock import pytest import yaml +from pytest_mock import MockFixture, MockType from pytest_simcore.helpers.webserver_projects import empty_project_data from simcore_service_webserver.application_settings_utils import AppConfigDict @@ -62,7 +62,7 @@ def activity_data(fake_data_dir: Path) -> Iterable[dict[str, Any]]: @pytest.fixture -def mock_orphaned_services(mocker) -> MagicMock: +def mock_orphaned_services(mocker: MockFixture) -> MockType: return mocker.patch( "simcore_service_webserver.garbage_collector._core.remove_orphaned_services", return_value="", @@ -70,9 +70,19 @@ def mock_orphaned_services(mocker) -> MagicMock: @pytest.fixture -def disable_gc_manual_guest_users(mocker): +def disable_gc_manual_guest_users(mocker: MockFixture) -> None: """Disable to avoid an almost instant cleanup of GUEST users with their projects""" mocker.patch( "simcore_service_webserver.garbage_collector._core.remove_users_manually_marked_as_guests", return_value=None, ) + + +@pytest.fixture +def disabled_setup_garbage_collector(mocker: MockFixture) -> MockType: + # WARNING: add it BEFORE `client` to have effect + return mocker.patch( + "simcore_service_webserver.application.setup_garbage_collector", + autospec=True, + return_value=False, + ) diff --git a/services/web/server/tests/unit/with_dbs/01/test_api_keys.py b/services/web/server/tests/unit/with_dbs/01/test_api_keys.py index aa6d4621209b..cfd7b9f154a6 100644 --- a/services/web/server/tests/unit/with_dbs/01/test_api_keys.py +++ b/services/web/server/tests/unit/with_dbs/01/test_api_keys.py @@ -3,6 +3,7 @@ # pylint: disable=unused-variable # pylint: disable=too-many-arguments +import asyncio from collections.abc import AsyncIterable from datetime import timedelta from http import HTTPStatus @@ -13,24 +14,21 @@ from aiohttp.test_utils import TestClient from faker import Faker from models_library.products import ProductName -from pytest_mock import MockerFixture +from pytest_mock import MockerFixture, MockType from pytest_simcore.helpers.assert_checks import assert_status from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from pytest_simcore.helpers.webserver_login import NewUser, UserInfoDict from servicelib.aiohttp import status -from servicelib.aiohttp.application_keys import APP_SETTINGS_KEY -from simcore_service_webserver.api_keys import _repository as repo -from simcore_service_webserver.api_keys._models import ApiKey -from simcore_service_webserver.api_keys._service import ( - get_or_create_api_key, - prune_expired_api_keys, +from simcore_service_webserver.api_keys import _repository, _service, api_keys_service +from simcore_service_webserver.api_keys.models import ApiKey +from simcore_service_webserver.application_settings import ( + ApplicationSettings, + get_application_settings, ) -from simcore_service_webserver.application_settings import GarbageCollectorSettings from simcore_service_webserver.db.models import UserRole from tenacity import ( retry_if_exception_type, - stop_after_attempt, stop_after_delay, wait_fixed, ) @@ -42,10 +40,11 @@ async def fake_user_api_keys( logged_user: UserInfoDict, osparc_product_name: ProductName, faker: Faker, -) -> AsyncIterable[list[int]]: +) -> AsyncIterable[list[ApiKey]]: assert client.app + api_keys: list[ApiKey] = [ - await repo.create_api_key( + await _repository.create_api_key( client.app, user_id=logged_user["id"], product_name=osparc_product_name, @@ -60,7 +59,7 @@ async def fake_user_api_keys( yield api_keys for api_key in api_keys: - await repo.delete_api_key( + await _repository.delete_api_key( client.app, api_key_id=api_key.id, user_id=logged_user["id"], @@ -85,11 +84,11 @@ def _get_user_access_parametrizations(expected_authed_status_code): _get_user_access_parametrizations(status.HTTP_200_OK), ) async def test_list_api_keys( + disabled_setup_garbage_collector: MockType, client: TestClient, logged_user: UserInfoDict, user_role: UserRole, expected: HTTPStatus, - disable_gc_manual_guest_users: None, ): resp = await client.get("/v0/auth/api-keys") data, errors = await assert_status(resp, expected) @@ -103,11 +102,11 @@ async def test_list_api_keys( _get_user_access_parametrizations(status.HTTP_200_OK), ) async def test_create_api_key( + disabled_setup_garbage_collector: MockType, client: TestClient, logged_user: UserInfoDict, user_role: UserRole, expected: HTTPStatus, - disable_gc_manual_guest_users: None, ): display_name = "foo" resp = await client.post("/v0/auth/api-keys", json={"displayName": display_name}) @@ -129,12 +128,12 @@ async def test_create_api_key( _get_user_access_parametrizations(status.HTTP_204_NO_CONTENT), ) async def test_delete_api_keys( + disabled_setup_garbage_collector: MockType, client: TestClient, fake_user_api_keys: list[ApiKey], logged_user: UserInfoDict, user_role: UserRole, expected: HTTPStatus, - disable_gc_manual_guest_users: None, ): resp = await client.delete("/v0/auth/api-keys/0") await assert_status(resp, expected) @@ -144,47 +143,58 @@ async def test_delete_api_keys( await assert_status(resp, expected) +EXPIRATION_WAIT_FACTOR = 1.2 + + @pytest.mark.parametrize( "user_role,expected", _get_user_access_parametrizations(status.HTTP_200_OK), ) async def test_create_api_key_with_expiration( + disabled_setup_garbage_collector: MockType, client: TestClient, logged_user: UserInfoDict, user_role: UserRole, expected: HTTPStatus, - disable_gc_manual_guest_users: None, + mocker: MockerFixture, ): assert client.app + # test gc is actually disabled + gc_prune_mock = mocker.patch( + "simcore_service_webserver.garbage_collector._tasks_api_keys.create_background_task_to_prune_api_keys", + spec=True, + ) + assert not gc_prune_mock.called + + expected_api_key = "foo" + # create api-keys with expiration interval expiration_interval = timedelta(seconds=1) resp = await client.post( "/v0/auth/api-keys", - json={"displayName": "foo", "expiration": expiration_interval.seconds}, + json={ + "displayName": expected_api_key, + "expiration": expiration_interval.seconds, + }, ) data, errors = await assert_status(resp, expected) if not errors: - assert data["displayName"] == "foo" + assert data["displayName"] == expected_api_key assert "apiKey" in data assert "apiSecret" in data # list created api-key resp = await client.get("/v0/auth/api-keys") data, _ = await assert_status(resp, expected) - assert [d["displayName"] for d in data] == ["foo"] + assert [d["displayName"] for d in data] == [expected_api_key] # wait for api-key for it to expire and force-run scheduled task - async for attempt in tenacity.AsyncRetrying( - wait=wait_fixed(1), - retry=retry_if_exception_type(AssertionError), - stop=stop_after_delay(5 * expiration_interval.seconds), - reraise=True, - ): - with attempt: - deleted = await prune_expired_api_keys(client.app) - assert deleted == ["foo"] + await asyncio.sleep(EXPIRATION_WAIT_FACTOR * expiration_interval.seconds) + + deleted = await api_keys_service.prune_expired_api_keys(client.app) + assert deleted == [expected_api_key] resp = await client.get("/v0/auth/api-keys") data, _ = await assert_status(resp, expected) @@ -192,6 +202,7 @@ async def test_create_api_key_with_expiration( async def test_get_or_create_api_key( + disabled_setup_garbage_collector: MockType, client: TestClient, ): async with NewUser( @@ -207,13 +218,15 @@ async def test_get_or_create_api_key( } # create once - created = await get_or_create_api_key(client.app, **options) + created = await _service.get_or_create_api_key(client.app, **options) assert created.display_name == "foo" assert created.api_key != created.api_secret # idempotent for _ in range(3): - assert await get_or_create_api_key(client.app, **options) == created + assert ( + await _service.get_or_create_api_key(client.app, **options) == created + ) @pytest.mark.parametrize( @@ -221,11 +234,11 @@ async def test_get_or_create_api_key( _get_user_access_parametrizations(status.HTTP_404_NOT_FOUND), ) async def test_get_not_existing_api_key( + disabled_setup_garbage_collector: MockType, client: TestClient, logged_user: UserInfoDict, user_role: UserRole, expected: HTTPException, - disable_gc_manual_guest_users: None, ): resp = await client.get("/v0/auth/api-keys/42") data, errors = await assert_status(resp, expected) @@ -248,20 +261,31 @@ async def app_environment( async def test_prune_expired_api_keys_task_is_triggered( - app_environment: EnvVarsDict, mocker: MockerFixture, client: TestClient + app_environment: EnvVarsDict, + mocker: MockerFixture, + client: TestClient, ): - mock = mocker.patch( - "simcore_service_webserver.api_keys._service._repository.prune_expired" - ) - settings = client.server.app[ # type: ignore - APP_SETTINGS_KEY - ].WEBSERVER_GARBAGE_COLLECTOR - assert isinstance(settings, GarbageCollectorSettings) + assert app_environment["WEBSERVER_GARBAGE_COLLECTOR"] is not None + + delete_expired_spy = mocker.spy(_repository, "delete_expired_api_keys") + + assert client.app + + settings: ApplicationSettings = get_application_settings(client.app) + assert settings.WEBSERVER_GARBAGE_COLLECTOR + + assert not delete_expired_spy.called + async for attempt in tenacity.AsyncRetrying( - stop=stop_after_attempt(5), + stop=stop_after_delay( + timedelta( + seconds=EXPIRATION_WAIT_FACTOR + * settings.WEBSERVER_GARBAGE_COLLECTOR.GARBAGE_COLLECTOR_EXPIRED_USERS_CHECK_INTERVAL_S + ) + ), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError), reraise=True, ): with attempt: - mock.assert_called() + delete_expired_spy.assert_called() diff --git a/services/web/server/tests/unit/with_dbs/01/test_api_keys_rpc.py b/services/web/server/tests/unit/with_dbs/01/test_api_keys_rpc.py index aa45fd9fd2ec..3053df387976 100644 --- a/services/web/server/tests/unit/with_dbs/01/test_api_keys_rpc.py +++ b/services/web/server/tests/unit/with_dbs/01/test_api_keys_rpc.py @@ -23,8 +23,8 @@ from settings_library.rabbit import RabbitSettings from simcore_postgres_database.models.users import UserRole from simcore_service_webserver.api_keys import _repository as repo -from simcore_service_webserver.api_keys._models import ApiKey from simcore_service_webserver.api_keys.errors import ApiKeyNotFoundError +from simcore_service_webserver.api_keys.models import ApiKey from simcore_service_webserver.application_settings import ApplicationSettings pytest_simcore_core_services_selection = [ diff --git a/services/web/server/tests/unit/with_dbs/03/invitations/test_login_handlers_registration_invitations.py b/services/web/server/tests/unit/with_dbs/03/invitations/test_login_handlers_registration_invitations.py index 8c4daca29df0..7a081e39cb64 100644 --- a/services/web/server/tests/unit/with_dbs/03/invitations/test_login_handlers_registration_invitations.py +++ b/services/web/server/tests/unit/with_dbs/03/invitations/test_login_handlers_registration_invitations.py @@ -113,7 +113,7 @@ def _extract_invitation_code_from_url(invitation_url: HttpUrl) -> str: @pytest.mark.acceptance_test() async def test_registration_to_different_product( mocker: MockerFixture, - all_products_names: list[ProductName], + app_products_names: list[ProductName], client: TestClient, guest_email: str, guest_password: str, @@ -146,8 +146,8 @@ async def _register_account(invitation_url: HttpUrl, product_deployed: ProductNa headers={X_PRODUCT_NAME_HEADER: product_deployed}, ) - product_a = all_products_names[0] - product_b = all_products_names[1] + product_a = app_products_names[0] + product_b = app_products_names[1] # PO creates an two invitations for guest in product A and product B invitation_product_a = await generate_invitation( diff --git a/services/web/server/tests/unit/with_dbs/03/invitations/test_products_rest_invitations.py b/services/web/server/tests/unit/with_dbs/03/invitations/test_products_rest_invitations.py index 473ba0c33c68..f384bbe46fbe 100644 --- a/services/web/server/tests/unit/with_dbs/03/invitations/test_products_rest_invitations.py +++ b/services/web/server/tests/unit/with_dbs/03/invitations/test_products_rest_invitations.py @@ -11,7 +11,7 @@ import pytest from aiohttp.test_utils import TestClient from faker import Faker -from models_library.api_schemas_webserver.product import ( +from models_library.api_schemas_webserver.products import ( InvitationGenerate, InvitationGenerated, ) diff --git a/services/web/server/tests/unit/with_dbs/03/login/test_login_2fa.py b/services/web/server/tests/unit/with_dbs/03/login/test_login_2fa.py index 62cddd9b6882..242bb75393c8 100644 --- a/services/web/server/tests/unit/with_dbs/03/login/test_login_2fa.py +++ b/services/web/server/tests/unit/with_dbs/03/login/test_login_2fa.py @@ -36,6 +36,7 @@ ) from simcore_service_webserver.login.storage import AsyncpgStorage from simcore_service_webserver.products import products_web +from simcore_service_webserver.products.errors import UnknownProductError from simcore_service_webserver.products.models import Product from simcore_service_webserver.users import preferences_api as user_preferences_api from twilio.base.exceptions import TwilioRestException @@ -370,7 +371,7 @@ async def test_send_email_code( ): request = make_mocked_request("GET", "/dummy", app=client.app) - with pytest.raises(KeyError): + with pytest.raises(UnknownProductError): # NOTE: this is a fake request and did not go through middlewares products_web.get_current_product(request) @@ -418,9 +419,9 @@ async def test_2fa_sms_failure_during_login( ): assert client.app - # Mocks error in graylog https://monitoring.osparc.io/graylog/search/649e7619ce6e0838a96e9bf1?q=%222FA%22&rangetype=relative&from=172800 mocker.patch( - "simcore_service_webserver.login._2fa_api.TwilioSettings.is_alphanumeric_supported", + # MD: Emulates error in graylog https://monitoring.osparc.io/graylog/search/649e7619ce6e0838a96e9bf1?q=%222FA%22&rangetype=relative&from=172800 + "simcore_service_webserver.login._2fa_api.twilio.rest.Client", autospec=True, side_effect=TwilioRestException( status=400, diff --git a/services/web/server/tests/unit/with_dbs/04/products/test_products_repository.py b/services/web/server/tests/unit/with_dbs/04/products/test_products_repository.py index 963ffaf2ea55..ed4550eee6d3 100644 --- a/services/web/server/tests/unit/with_dbs/04/products/test_products_repository.py +++ b/services/web/server/tests/unit/with_dbs/04/products/test_products_repository.py @@ -8,12 +8,11 @@ from decimal import Decimal from typing import Any -import aiopg.sa import pytest import sqlalchemy as sa from aiohttp import web from aiohttp.test_utils import TestClient, make_mocked_request -from models_library.products import ProductName, ProductStripeInfoGet +from models_library.products import ProductName from pytest_simcore.helpers.faker_factories import random_product, random_product_price from pytest_simcore.helpers.postgres_tools import sync_insert_and_get_row_lifespan from simcore_postgres_database import utils_products @@ -36,6 +35,7 @@ from simcore_service_webserver.products._web_middlewares import ( _get_default_product_name, ) +from sqlalchemy.ext.asyncio import AsyncEngine @pytest.fixture(scope="module") @@ -186,10 +186,10 @@ async def product_repository(app: web.Application) -> ProductRepository: async def test_utils_products_and_webserver_default_product_in_sync( app: web.Application, product_repository: ProductRepository, - aiopg_engine: aiopg.sa.engine.Engine, + asyncpg_engine: AsyncEngine, ): # tests definitions of default from utle_products and web-server.products are in sync - async with aiopg_engine.acquire() as conn: + async with asyncpg_engine.connect() as conn: default_product_name = await utils_products.get_default_product_name(conn) assert default_product_name == _get_default_product_name(app) @@ -232,12 +232,12 @@ async def test_product_repository_get_product_stripe_info( product_repository: ProductRepository, ): product_name = "tis" - stripe_info = await product_repository.get_product_stripe_info(product_name) - assert isinstance(stripe_info, ProductStripeInfoGet) + stripe_info = await product_repository.get_product_stripe_info_or_none(product_name) + assert stripe_info product_name = "s4l" - with pytest.raises(ValueError, match=product_name): - stripe_info = await product_repository.get_product_stripe_info(product_name) + stripe_info = await product_repository.get_product_stripe_info_or_none(product_name) + assert stripe_info is None async def test_product_repository_get_template_content( diff --git a/services/web/server/tests/unit/with_dbs/04/products/test_products_rest.py b/services/web/server/tests/unit/with_dbs/04/products/test_products_rest.py index 3587285742d9..f9a047ef50eb 100644 --- a/services/web/server/tests/unit/with_dbs/04/products/test_products_rest.py +++ b/services/web/server/tests/unit/with_dbs/04/products/test_products_rest.py @@ -10,7 +10,7 @@ import pytest from aiohttp.test_utils import TestClient -from models_library.api_schemas_webserver.product import ProductGet, ProductUIGet +from models_library.api_schemas_webserver.products import ProductGet, ProductUIGet from models_library.products import ProductName from pytest_simcore.helpers.assert_checks import assert_status from pytest_simcore.helpers.webserver_login import UserInfoDict @@ -128,7 +128,7 @@ async def test_get_product( ], ) async def test_get_current_product_ui( - all_products_names: list[ProductName], + app_products_names: list[ProductName], product_name: ProductName, logged_user: UserInfoDict, client: TestClient, @@ -136,7 +136,7 @@ async def test_get_current_product_ui( expected_status_code: int, ): assert logged_user["role"] == user_role.value - assert product_name in all_products_names + assert product_name in app_products_names # give access to user to this product assert client.app diff --git a/services/web/server/tests/unit/with_dbs/04/products/test_products_rpc.py b/services/web/server/tests/unit/with_dbs/04/products/test_products_rpc.py index 4505a6f4e3ee..08763afefa21 100644 --- a/services/web/server/tests/unit/with_dbs/04/products/test_products_rpc.py +++ b/services/web/server/tests/unit/with_dbs/04/products/test_products_rpc.py @@ -8,7 +8,8 @@ import pytest from models_library.api_schemas_webserver import WEBSERVER_RPC_NAMESPACE -from models_library.products import CreditResultGet, ProductName +from models_library.api_schemas_webserver.products import CreditResultRpcGet +from models_library.products import ProductName from models_library.rabbitmq_basic_types import RPCMethodName from pydantic import TypeAdapter from pytest_mock import MockerFixture @@ -74,7 +75,7 @@ async def test_get_credit_amount( dollar_amount=Decimal(900), product_name="s4l", ) - credit_result = CreditResultGet.model_validate(result) + credit_result = CreditResultRpcGet.model_validate(result) assert credit_result.credit_amount == 100 result = await rpc_client.request( @@ -83,7 +84,7 @@ async def test_get_credit_amount( dollar_amount=Decimal(900), product_name="tis", ) - credit_result = CreditResultGet.model_validate(result) + credit_result = CreditResultRpcGet.model_validate(result) assert credit_result.credit_amount == 180 with pytest.raises(RPCServerError) as exc_info: diff --git a/services/web/server/tests/unit/with_dbs/04/products/test_products_service.py b/services/web/server/tests/unit/with_dbs/04/products/test_products_service.py index 4309c1f2aa8d..3f30f84b9293 100644 --- a/services/web/server/tests/unit/with_dbs/04/products/test_products_service.py +++ b/services/web/server/tests/unit/with_dbs/04/products/test_products_service.py @@ -4,13 +4,27 @@ # pylint: disable=too-many-arguments +from decimal import Decimal + import pytest from aiohttp import web from aiohttp.test_utils import TestServer from models_library.products import ProductName -from simcore_service_webserver.products import products_service +from pydantic import TypeAdapter, ValidationError +from pytest_mock import MockerFixture +from servicelib.exceptions import InvalidConfig +from simcore_postgres_database.utils_products_prices import ProductPriceInfo +from simcore_service_webserver.products import _service, products_service +from simcore_service_webserver.products._models import ProductStripeInfo from simcore_service_webserver.products._repository import ProductRepository -from simcore_service_webserver.products.errors import ProductPriceNotDefinedError +from simcore_service_webserver.products.errors import ( + BelowMinimumPaymentError, + MissingStripeConfigError, + ProductNotFoundError, + ProductPriceNotDefinedError, + ProductTemplateNotFoundError, +) +from simcore_service_webserver.products.models import Product @pytest.fixture @@ -22,8 +36,32 @@ def app( return web_server.app -async def test_get_product(app: web.Application, default_product_name: ProductName): +async def test_load_products(app: web.Application): + products = await _service.load_products(app) + assert isinstance(products, list) + assert all(isinstance(product, Product) for product in products) + + +async def test_load_products_validation_error(app: web.Application, mocker): + mock_repo = mocker.patch( + "simcore_service_webserver.products._service.ProductRepository.create_from_app" + ) + + try: + TypeAdapter(int).validate_python("not-an-int") + except ValidationError as validation_error: + mock_repo.return_value.list_products.side_effect = validation_error + + with pytest.raises(InvalidConfig, match="Invalid product configuration in db"): + await _service.load_products(app) + +async def test_get_default_product_name(app: web.Application): + default_product_name = await _service.get_default_product_name(app) + assert isinstance(default_product_name, ProductName) + + +async def test_get_product(app: web.Application, default_product_name: ProductName): product = products_service.get_product(app, product_name=default_product_name) assert product.name == default_product_name @@ -32,33 +70,127 @@ async def test_get_product(app: web.Application, default_product_name: ProductNa assert products[0] == product -async def test_get_product_ui(app: web.Application, default_product_name: ProductName): - # this feature is currently setup from adminer by an operator +async def test_products_on_uninitialized_app(default_product_name: ProductName): + uninit_app = web.Application() + with pytest.raises(ProductNotFoundError): + _service.get_product(uninit_app, default_product_name) + + +async def test_list_products_names(app: web.Application): + product_names = await products_service.list_products_names(app) + assert isinstance(product_names, list) + assert all(isinstance(name, ProductName) for name in product_names) + +async def test_get_credit_price_info( + app: web.Application, default_product_name: ProductName +): + price_info = await _service.get_credit_price_info( + app, product_name=default_product_name + ) + assert price_info is None or isinstance(price_info, ProductPriceInfo) + + +async def test_get_product_ui(app: web.Application, default_product_name: ProductName): repo = ProductRepository.create_from_app(app) ui = await products_service.get_product_ui(repo, product_name=default_product_name) assert ui == {}, "Expected empty by default" + with pytest.raises(ProductNotFoundError): + await products_service.get_product_ui(repo, product_name="undefined") + + +async def test_get_credit_amount( + app: web.Application, default_product_name: ProductName, mocker: MockerFixture +): + # Test when ProductPriceNotDefinedError is raised + with pytest.raises(ProductPriceNotDefinedError): + await products_service.get_credit_amount( + app, dollar_amount=1, product_name=default_product_name + ) + + +async def test_get_credit_amount_with_repo_faking_data( + default_product_name: ProductName, mocker: MockerFixture +): + # NO need of database since repo is mocked + app = web.Application() + + # Mock the repository to return a valid price info + mock_repo = mocker.patch( + "simcore_service_webserver.products._service.ProductRepository.create_from_app" + ) + + async def _get_product_latest_price_info_or_none(*args, **kwargs): + return ProductPriceInfo( + usd_per_credit=Decimal("10.0"), min_payment_amount_usd=Decimal("5.0") + ) + + mock_repo.return_value.get_product_latest_price_info_or_none.side_effect = ( + _get_product_latest_price_info_or_none + ) + + # Test when BelowMinimumPaymentError is raised + with pytest.raises(BelowMinimumPaymentError): + await products_service.get_credit_amount( + app, dollar_amount=Decimal("3.0"), product_name=default_product_name + ) + + # Test when CreditResultGet is returned successfully + credit_result = await products_service.get_credit_amount( + app, dollar_amount=Decimal("10.0"), product_name=default_product_name + ) + assert credit_result.credit_amount == Decimal("1.0") + assert credit_result.product_name == default_product_name + async def test_get_product_stripe_info( app: web.Application, default_product_name: ProductName ): - # this feature is currently setup from adminer by an operator - - # default is not configured - with pytest.raises(ValueError, match=default_product_name): + # database has no info + with pytest.raises(MissingStripeConfigError, match=default_product_name): await products_service.get_product_stripe_info( app, product_name=default_product_name ) -async def test_get_credit_amount( - app: web.Application, default_product_name: ProductName +async def test_get_product_stripe_info_with_repo_faking_data( + default_product_name: ProductName, mocker: MockerFixture ): - # this feature is currently setup from adminer by an operator + # NO need of database since repo is mocked + app = web.Application() - # default is not configured - with pytest.raises(ProductPriceNotDefinedError): - await products_service.get_credit_amount( - app, dollar_amount=1, product_name=default_product_name - ) + # Mock the repository to return a valid stripe info + mock_repo = mocker.patch( + "simcore_service_webserver.products._service.ProductRepository.create_from_app" + ) + + # Test when stripe info is returned successfully + expected_stripe_info = ProductStripeInfo( + stripe_price_id="price_id", stripe_tax_rate_id="tax_id" + ) + + async def _mock(*args, **kw): + return expected_stripe_info + + mock_repo.return_value.get_product_stripe_info_or_none.side_effect = _mock + + stripe_info = await products_service.get_product_stripe_info( + app, product_name=default_product_name + ) + assert stripe_info == expected_stripe_info + + +async def test_get_template_content(app: web.Application): + template_name = "some_template" + with pytest.raises(ProductTemplateNotFoundError): + await _service.get_template_content(app, template_name=template_name) + + +async def test_auto_create_products_groups(app: web.Application): + groups = await _service.auto_create_products_groups(app) + assert isinstance(groups, dict) + + assert all( + group_id is not None for group_id in groups.values() + ), f"Invalid {groups}" diff --git a/services/web/server/tests/unit/with_dbs/conftest.py b/services/web/server/tests/unit/with_dbs/conftest.py index 10a6e91c5ad4..25ca1f218e01 100644 --- a/services/web/server/tests/unit/with_dbs/conftest.py +++ b/services/web/server/tests/unit/with_dbs/conftest.py @@ -1,20 +1,13 @@ -"""Configuration for unit testing with a postgress fixture - -- Unit testing of webserver app with a postgress service as fixture -- Starts test session by running a postgres container as a fixture (see postgress_service) - -IMPORTANT: remember that these are still unit-tests! -""" - -# nopycln: file # pylint: disable=redefined-outer-name # pylint: disable=unused-argument # pylint: disable=unused-variable +# pylint: disable=too-many-arguments import asyncio import random import sys import textwrap +import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator from copy import deepcopy from decimal import Decimal @@ -35,7 +28,6 @@ from aiohttp import web from aiohttp.test_utils import TestClient, TestServer from aiopg.sa import create_engine -from aiopg.sa.connection import SAConnection from faker import Faker from models_library.api_schemas_directorv2.dynamic_services import DynamicServiceGet from models_library.products import ProductName @@ -76,6 +68,7 @@ FRONTEND_APPS_AVAILABLE, ) from sqlalchemy import exc as sql_exceptions +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine CURRENT_DIR = Path(sys.argv[0] if __name__ == "__main__" else __file__).resolve().parent @@ -534,6 +527,13 @@ async def aiopg_engine(postgres_db: sa.engine.Engine) -> AsyncIterator[aiopg.sa. engine = await create_engine(f"{postgres_db.url}") assert engine + warnings.warn( + "The 'aiopg_engine' fixture is deprecated and will be removed in a future release. " + "Please use 'asyncpg_engine' fixture instead.", + DeprecationWarning, + stacklevel=2, + ) + yield engine if engine: @@ -541,6 +541,34 @@ async def aiopg_engine(postgres_db: sa.engine.Engine) -> AsyncIterator[aiopg.sa. await engine.wait_closed() +@pytest.fixture +async def asyncpg_engine( # <-- WE SHOULD USE THIS ONE instead of aiopg_engine + postgres_db: sa.engine.Engine, is_pdb_enabled: bool +) -> AsyncIterable[AsyncEngine]: + # NOTE: call to postgres BEFORE app starts + dsn = f"{postgres_db.url}".replace("postgresql://", "postgresql+asyncpg://") + minsize = 1 + maxsize = 50 + + engine: AsyncEngine = create_async_engine( + dsn, + pool_size=minsize, + max_overflow=maxsize - minsize, + connect_args={ + "server_settings": { + "application_name": "webserver_tests_with_dbs:asyncpg_engine" + } + }, + pool_pre_ping=True, # https://docs.sqlalchemy.org/en/14/core/pooling.html#dealing-with-disconnects + future=True, # this uses sqlalchemy 2.0 API, shall be removed when sqlalchemy 2.0 is released + echo=is_pdb_enabled, + ) + + yield engine + + await engine.dispose() + + # REDIS CORE SERVICE ------------------------------------------------------ def _is_redis_responsive(host: str, port: int, password: str) -> bool: # username via https://stackoverflow.com/a/78236235 @@ -678,23 +706,13 @@ async def with_permitted_override_services_specifications( @pytest.fixture -async def _pre_connection(postgres_db: sa.engine.Engine) -> AsyncIterable[SAConnection]: - # NOTE: call to postgres BEFORE app starts - async with await create_engine( - f"{postgres_db.url}" - ) as engine, engine.acquire() as conn: - yield conn - - -@pytest.fixture -async def all_products_names( - _pre_connection: SAConnection, +async def app_products_names( + asyncpg_engine: AsyncEngine, ) -> AsyncIterable[list[ProductName]]: - # default product - result = await _pre_connection.execute( - products.select().order_by(products.c.priority) - ) - rows = await result.fetchall() + async with asyncpg_engine.connect() as conn: + # default product + result = await conn.execute(products.select().order_by(products.c.priority)) + rows = result.fetchall() assert rows assert len(rows) == 1 osparc_product_row = rows[0] @@ -705,37 +723,41 @@ async def all_products_names( priority = 1 for name in FRONTEND_APPS_AVAILABLE: if name != FRONTEND_APP_DEFAULT: - result = await _pre_connection.execute( - products.insert().values( - random_product( - name=name, - priority=priority, - login_settings=osparc_product_row.login_settings, - group_id=None, + + async with asyncpg_engine.begin() as conn: + result = await conn.execute( + products.insert().values( + random_product( + name=name, + priority=priority, + login_settings=osparc_product_row.login_settings, + group_id=None, + ) ) ) - ) - await get_or_create_product_group(_pre_connection, product_name=name) + await get_or_create_product_group(conn, product_name=name) priority += 1 - # get all products - result = await _pre_connection.execute( - sa.select(products.c.name).order_by(products.c.priority) - ) - rows = await result.fetchall() + async with asyncpg_engine.connect() as conn: + # get all products + result = await conn.execute( + sa.select(products.c.name).order_by(products.c.priority) + ) + rows = result.fetchall() yield [r.name for r in rows] - await _pre_connection.execute(products_prices.delete()) - await _pre_connection.execute( - products.delete().where(products.c.name != FRONTEND_APP_DEFAULT) - ) + async with asyncpg_engine.begin() as conn: + await conn.execute(products_prices.delete()) + await conn.execute( + products.delete().where(products.c.name != FRONTEND_APP_DEFAULT) + ) @pytest.fixture async def all_product_prices( - _pre_connection: SAConnection, - all_products_names: list[ProductName], + asyncpg_engine: AsyncEngine, + app_products_names: list[ProductName], faker: Faker, ) -> dict[ProductName, Decimal | None]: """Initial list of prices for all products""" @@ -747,23 +769,24 @@ async def all_product_prices( "tiplite": Decimal(5), "s4l": Decimal(9), "s4llite": Decimal(0), # free of charge - "s4lacad": Decimal(1.1), + "s4lacad": Decimal("1.1"), } result = {} - for product_name in all_products_names: + for product_name in app_products_names: usd_or_none = product_price.get(product_name) if usd_or_none is not None: - await _pre_connection.execute( - products_prices.insert().values( - product_name=product_name, - usd_per_credit=usd_or_none, - comment=faker.sentence(), - min_payment_amount_usd=10, - stripe_price_id=faker.pystr(), - stripe_tax_rate_id=faker.pystr(), + async with asyncpg_engine.begin() as conn: + await conn.execute( + products_prices.insert().values( + product_name=product_name, + usd_per_credit=usd_or_none, + comment=faker.sentence(), + min_payment_amount_usd=10, + stripe_price_id=faker.pystr(), + stripe_tax_rate_id=faker.pystr(), + ) ) - ) result[product_name] = usd_or_none @@ -773,23 +796,23 @@ async def all_product_prices( @pytest.fixture async def latest_osparc_price( all_product_prices: dict[ProductName, Decimal], - _pre_connection: SAConnection, + asyncpg_engine: AsyncEngine, ) -> Decimal: """This inserts a new price for osparc in the history (i.e. the old price of osparc is still in the database) """ - - usd = await _pre_connection.scalar( - products_prices.insert() - .values( - product_name="osparc", - usd_per_credit=all_product_prices["osparc"] + 5, - comment="New price for osparc", - stripe_price_id="stripe-price-id", - stripe_tax_rate_id="stripe-tax-rate-id", + async with asyncpg_engine.begin() as conn: + usd = await conn.scalar( + products_prices.insert() + .values( + product_name="osparc", + usd_per_credit=all_product_prices["osparc"] + 5, + comment="New price for osparc", + stripe_price_id="stripe-price-id", + stripe_tax_rate_id="stripe-tax-rate-id", + ) + .returning(products_prices.c.usd_per_credit) ) - .returning(products_prices.c.usd_per_credit) - ) assert usd is not None assert usd != all_product_prices["osparc"] return Decimal(usd)