diff --git a/packages/notifications-library/tests/with_db/conftest.py b/packages/notifications-library/tests/with_db/conftest.py index 9dda5da676d3..0ddf0d9f464e 100644 --- a/packages/notifications-library/tests/with_db/conftest.py +++ b/packages/notifications-library/tests/with_db/conftest.py @@ -16,11 +16,14 @@ from models_library.users import UserID from notifications_library._templates import get_default_named_templates from pydantic import validate_call +from pytest_simcore.helpers.postgres_tools import insert_and_get_row_lifespan +from pytest_simcore.helpers.postgres_users import ( + insert_and_get_user_and_secrets_lifespan, +) from simcore_postgres_database.models.jinja2_templates import jinja2_templates from simcore_postgres_database.models.payments_transactions import payments_transactions from simcore_postgres_database.models.products import products from simcore_postgres_database.models.products_to_templates import products_to_templates -from simcore_postgres_database.models.users import users from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio.engine import AsyncEngine @@ -50,16 +53,11 @@ async def user( and injects a user in db """ assert user_id == user["id"] - pk_args = users.c.id, user["id"] - - # NOTE: creation of primary group and setting `groupid`` is automatically triggered after creation of user by postgres - async with sqlalchemy_async_engine.begin() as conn: - row: Row = await _insert_and_get_row(conn, users, user, *pk_args) - - yield row._asdict() - - async with sqlalchemy_async_engine.begin() as conn: - await _delete_row(conn, users, *pk_args) + async with insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + sqlalchemy_async_engine, + **user, + ) as row: + yield row @pytest.fixture @@ -82,15 +80,14 @@ async def product( # NOTE: osparc product is already in db. This is another product assert product["name"] != "osparc" - pk_args = products.c.name, product["name"] - - async with sqlalchemy_async_engine.begin() as conn: - row: Row = await _insert_and_get_row(conn, products, product, *pk_args) - - yield row._asdict() - - async with sqlalchemy_async_engine.begin() as conn: - await _delete_row(conn, products, *pk_args) + async with insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + sqlalchemy_async_engine, + table=products, + values=product, + pk_col=products.c.name, + pk_value=product["name"], + ) as row: + yield row @pytest.fixture diff --git a/packages/postgres-database/src/simcore_postgres_database/migration/versions/5679165336c8_new_users_secrets.py b/packages/postgres-database/src/simcore_postgres_database/migration/versions/5679165336c8_new_users_secrets.py new file mode 100644 index 000000000000..1187c800a65c --- /dev/null +++ b/packages/postgres-database/src/simcore_postgres_database/migration/versions/5679165336c8_new_users_secrets.py @@ -0,0 +1,77 @@ +"""new users secrets + +Revision ID: 5679165336c8 +Revises: 61b98a60e934 +Create Date: 2025-07-17 17:07:20.200038+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5679165336c8" +down_revision = "61b98a60e934" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "users_secrets", + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("password_hash", sa.String(), nullable=False), + sa.Column( + "modified", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + name="fk_users_secrets_user_id_users", + onupdate="CASCADE", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("user_id", name="users_secrets_pkey"), + ) + + # Copy password data from users table to users_secrets table + op.execute( + sa.DDL( + """ + INSERT INTO users_secrets (user_id, password_hash, modified) + SELECT id, password_hash, created_at + FROM users + WHERE password_hash IS NOT NULL + """ + ) + ) + + op.drop_column("users", "password_hash") + + +def downgrade(): + # Add column as nullable first + op.add_column( + "users", + sa.Column("password_hash", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + + # Copy password data back from users_secrets table to users table + op.execute( + sa.DDL( + """ + UPDATE users + SET password_hash = us.password_hash + FROM users_secrets us + WHERE users.id = us.user_id + """ + ) + ) + + # Now make the column NOT NULL + op.alter_column("users", "password_hash", nullable=False) + + op.drop_table("users_secrets") diff --git a/packages/postgres-database/src/simcore_postgres_database/models/_common.py b/packages/postgres-database/src/simcore_postgres_database/models/_common.py index 47bfeb6ebf08..6b2405548547 100644 --- a/packages/postgres-database/src/simcore_postgres_database/models/_common.py +++ b/packages/postgres-database/src/simcore_postgres_database/models/_common.py @@ -16,24 +16,28 @@ class RefActions: NO_ACTION: Final[str] = "NO ACTION" -def column_created_datetime(*, timezone: bool = True) -> sa.Column: +def column_created_datetime( + *, timezone: bool = True, doc="Timestamp auto-generated upon creation" +) -> sa.Column: return sa.Column( "created", sa.DateTime(timezone=timezone), nullable=False, server_default=sa.sql.func.now(), - doc="Timestamp auto-generated upon creation", + doc=doc, ) -def column_modified_datetime(*, timezone: bool = True) -> sa.Column: +def column_modified_datetime( + *, timezone: bool = True, doc="Timestamp with last row update" +) -> sa.Column: return sa.Column( "modified", sa.DateTime(timezone=timezone), nullable=False, server_default=sa.sql.func.now(), onupdate=sa.sql.func.now(), - doc="Timestamp with last row update", + doc=doc, ) diff --git a/packages/postgres-database/src/simcore_postgres_database/models/users.py b/packages/postgres-database/src/simcore_postgres_database/models/users.py index 7be2161ff864..62dffd58c66d 100644 --- a/packages/postgres-database/src/simcore_postgres_database/models/users.py +++ b/packages/postgres-database/src/simcore_postgres_database/models/users.py @@ -67,15 +67,6 @@ "NOTE: new policy (NK) is that the same phone can be reused therefore it does not has to be unique", ), # - # User Secrets ------------------ - # - sa.Column( - "password_hash", - sa.String(), - nullable=False, - doc="Hashed password", - ), - # # User Account ------------------ # sa.Column( diff --git a/packages/postgres-database/src/simcore_postgres_database/models/users_secrets.py b/packages/postgres-database/src/simcore_postgres_database/models/users_secrets.py new file mode 100644 index 000000000000..1a1ae04ec637 --- /dev/null +++ b/packages/postgres-database/src/simcore_postgres_database/models/users_secrets.py @@ -0,0 +1,34 @@ +import sqlalchemy as sa + +from ._common import RefActions, column_modified_datetime +from .base import metadata + +__all__: tuple[str, ...] = ("users_secrets",) + +users_secrets = sa.Table( + "users_secrets", + metadata, + # + # User Secrets ------------------ + # + sa.Column( + "user_id", + sa.BigInteger(), + sa.ForeignKey( + "users.id", + name="fk_users_secrets_user_id_users", + onupdate=RefActions.CASCADE, + ondelete=RefActions.CASCADE, + ), + nullable=False, + ), + sa.Column( + "password_hash", + sa.String(), + nullable=False, + doc="Hashed password", + ), + column_modified_datetime(timezone=True, doc="Last password modification timestamp"), + # --------------------------- + sa.PrimaryKeyConstraint("user_id", name="users_secrets_pkey"), +) diff --git a/packages/postgres-database/src/simcore_postgres_database/utils_users.py b/packages/postgres-database/src/simcore_postgres_database/utils_users.py index 587f90ee504b..295806dffa7d 100644 --- a/packages/postgres-database/src/simcore_postgres_database/utils_users.py +++ b/packages/postgres-database/src/simcore_postgres_database/utils_users.py @@ -5,17 +5,21 @@ import re import secrets import string +from dataclasses import dataclass, fields from datetime import datetime from typing import Any, Final import sqlalchemy as sa -from common_library.async_tools import maybe_await from sqlalchemy import Column +from sqlalchemy.engine.result import Row +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.sql import Select -from ._protocols import DBConnection -from .aiopg_errors import UniqueViolation from .models.users import UserRole, UserStatus, users from .models.users_details import users_pre_registration_details +from .models.users_secrets import users_secrets +from .utils_repos import pass_or_acquire_connection, transaction_context class BaseUserRepoError(Exception): @@ -52,74 +56,124 @@ def generate_alternative_username(username: str) -> str: return f"{username}_{_generate_random_chars()}" +@dataclass(frozen=True) +class UserRow: + id: int + name: str + email: str + role: UserRole + status: UserStatus + first_name: str | None = None + last_name: str | None = None + phone: str | None = None + + @classmethod + def from_row(cls, row: Row) -> "UserRow": + return cls(**{f.name: getattr(row, f.name) for f in fields(cls)}) + + class UsersRepo: - @staticmethod + _user_columns = ( + users.c.id, + users.c.name, + users.c.email, + users.c.role, + users.c.status, + users.c.first_name, + users.c.last_name, + users.c.phone, + ) + + def __init__(self, engine: AsyncEngine): + self._engine = engine + + async def _get_scalar_or_raise( + self, + query: Select, + connection: AsyncConnection | None = None, + ) -> Any: + """Execute a scalar query and raise UserNotFoundInRepoError if no value found.""" + async with pass_or_acquire_connection(self._engine, connection) as conn: + value = await conn.scalar(query) + if value is not None: + return value + raise UserNotFoundInRepoError + async def new_user( - conn: DBConnection, + self, + connection: AsyncConnection | None = None, + *, email: str, password_hash: str, status: UserStatus, expires_at: datetime | None, - ) -> Any: - data: dict[str, Any] = { + role: UserRole = UserRole.USER, + ) -> UserRow: + user_data: dict[str, Any] = { "name": _generate_username_from_email(email), "email": email, - "password_hash": password_hash, "status": status, - "role": UserRole.USER, + "role": role, "expires_at": expires_at, } user_id = None while user_id is None: try: - user_id = await conn.scalar( - users.insert().values(**data).returning(users.c.id) - ) - except UniqueViolation: - data["name"] = generate_alternative_username(data["name"]) - - result = await conn.execute( - sa.select( - users.c.id, - users.c.name, - users.c.email, - users.c.role, - users.c.status, - ).where(users.c.id == user_id) - ) - return await maybe_await(result.first()) + async with transaction_context(self._engine, connection) as conn: + # Insert user record + user_id = await conn.scalar( + users.insert().values(**user_data).returning(users.c.id) + ) + + # Insert password hash into users_secrets table + await conn.execute( + users_secrets.insert().values( + user_id=user_id, + password_hash=password_hash, + ) + ) + except IntegrityError: + user_data["name"] = generate_alternative_username(user_data["name"]) + user_id = None # Reset to retry with new username + + async with pass_or_acquire_connection(self._engine, connection) as conn: + result = await conn.execute( + sa.select(*self._user_columns).where(users.c.id == user_id) + ) + return UserRow.from_row(result.one()) - @staticmethod async def link_and_update_user_from_pre_registration( - conn: DBConnection, + self, + connection: AsyncConnection | None = None, *, new_user_id: int, new_user_email: str, - update_user: bool = True, ) -> None: """After a user is created, it can be associated with information provided during invitation - WARNING: Use ONLY upon new user creation. It might override user_details.user_id, users.first_name, users.last_name etc if already applied - or changes happen in users table + Links ALL pre-registrations for the given email to the user, regardless of product_name. + + WARNING: Use ONLY upon new user creation. It might override user_details.user_id, + users.first_name, users.last_name etc if already applied or changes happen in users table """ assert new_user_email # nosec assert new_user_id > 0 # nosec - # link both tables first - result = await conn.execute( - users_pre_registration_details.update() - .where(users_pre_registration_details.c.pre_email == new_user_email) - .values(user_id=new_user_id) - ) + async with transaction_context(self._engine, connection) as conn: + # Link ALL pre-registrations for this email to the user + result = await conn.execute( + users_pre_registration_details.update() + .where(users_pre_registration_details.c.pre_email == new_user_email) + .values(user_id=new_user_id) + ) - if update_user: # COPIES some pre-registration details to the users table pre_columns = ( users_pre_registration_details.c.pre_first_name, users_pre_registration_details.c.pre_last_name, - # NOTE: pre_phone is not copied since it has to be validated. Otherwise, if - # phone is wrong, currently user won't be able to login! + # NOTE: pre_phone is not copied since it has to be validated. + # Otherwise, if phone is wrong, currently user won't be able to login! ) assert {c.name for c in pre_columns} == { # nosec @@ -133,103 +187,162 @@ async def link_and_update_user_from_pre_registration( and c.name.startswith("pre_") }, "Different pre-cols detected. This code might need an update update" + # Get the most recent pre-registration data to copy to users table result = await conn.execute( - sa.select(*pre_columns).where( - users_pre_registration_details.c.pre_email == new_user_email - ) + sa.select(*pre_columns) + .where(users_pre_registration_details.c.pre_email == new_user_email) + .order_by(users_pre_registration_details.c.created.desc()) + .limit(1) ) - if pre_registration_details_data := result.first(): - # NOTE: could have many products! which to use? + if pre_registration_details_data := result.one_or_none(): await conn.execute( users.update() .where(users.c.id == new_user_id) .values( - first_name=pre_registration_details_data.pre_first_name, # type: ignore[union-attr] - last_name=pre_registration_details_data.pre_last_name, # type: ignore[union-attr] + first_name=pre_registration_details_data.pre_first_name, + last_name=pre_registration_details_data.pre_last_name, ) ) - @staticmethod - def get_billing_details_query(user_id: int): - return ( - sa.select( - users.c.first_name, - users.c.last_name, - users_pre_registration_details.c.institution, - users_pre_registration_details.c.address, - users_pre_registration_details.c.city, - users_pre_registration_details.c.state, - users_pre_registration_details.c.country, - users_pre_registration_details.c.postal_code, - users.c.phone, - ) - .select_from( - users.join( - users_pre_registration_details, - users.c.id == users_pre_registration_details.c.user_id, - ) - ) - .where(users.c.id == user_id) + async def get_role( + self, connection: AsyncConnection | None = None, *, user_id: int + ) -> UserRole: + value = await self._get_scalar_or_raise( + sa.select(users.c.role).where(users.c.id == user_id), + connection=connection, ) - - @staticmethod - async def get_billing_details(conn: DBConnection, user_id: int) -> Any | None: - result = await conn.execute( - UsersRepo.get_billing_details_query(user_id=user_id) + assert isinstance(value, UserRole) # nosec + return UserRole(value) + + async def get_email( + self, connection: AsyncConnection | None = None, *, user_id: int + ) -> str: + value = await self._get_scalar_or_raise( + sa.select(users.c.email).where(users.c.id == user_id), + connection=connection, ) - return await maybe_await(result.fetchone()) + assert isinstance(value, str) # nosec + return value - @staticmethod - async def get_role(conn: DBConnection, user_id: int) -> UserRole: - value: UserRole | None = await conn.scalar( - sa.select(users.c.role).where(users.c.id == user_id) + async def get_active_user_email( + self, connection: AsyncConnection | None = None, *, user_id: int + ) -> str: + value = await self._get_scalar_or_raise( + sa.select(users.c.email).where( + (users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id) + ), + connection=connection, + ) + assert isinstance(value, str) # nosec + return value + + async def get_password_hash( + self, connection: AsyncConnection | None = None, *, user_id: int + ) -> str: + value = await self._get_scalar_or_raise( + sa.select(users_secrets.c.password_hash).where( + users_secrets.c.user_id == user_id + ), + connection=connection, ) - if value: - assert isinstance(value, UserRole) # nosec - return UserRole(value) + assert isinstance(value, str) # nosec + return value - raise UserNotFoundInRepoError + async def get_user_by_email_or_none( + self, connection: AsyncConnection | None = None, *, email: str + ) -> UserRow | None: + async with pass_or_acquire_connection(self._engine, connection) as conn: + result = await conn.execute( + sa.select(*self._user_columns).where(users.c.email == email.lower()) + ) + row = result.one_or_none() + return UserRow.from_row(row) if row else None - @staticmethod - async def get_email(conn: DBConnection, user_id: int) -> str: - value: str | None = await conn.scalar( - sa.select(users.c.email).where(users.c.id == user_id) - ) - if value: - assert isinstance(value, str) # nosec - return value + async def get_user_by_id_or_none( + self, connection: AsyncConnection | None = None, *, user_id: int + ) -> UserRow | None: + async with pass_or_acquire_connection(self._engine, connection) as conn: + result = await conn.execute( + sa.select(*self._user_columns).where(users.c.id == user_id) + ) + row = result.one_or_none() + return UserRow.from_row(row) if row else None - raise UserNotFoundInRepoError + async def update_user_phone( + self, connection: AsyncConnection | None = None, *, user_id: int, phone: str + ) -> None: + async with transaction_context(self._engine, connection) as conn: + await conn.execute( + users.update().where(users.c.id == user_id).values(phone=phone) + ) - @staticmethod - async def get_active_user_email(conn: DBConnection, user_id: int) -> str: - value: str | None = await conn.scalar( - sa.select(users.c.email).where( - (users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id) + async def update_user_password_hash( + self, + connection: AsyncConnection | None = None, + *, + user_id: int, + password_hash: str, + ) -> None: + async with transaction_context(self._engine, connection) as conn: + await self.get_password_hash( + connection=conn, user_id=user_id + ) # ensure user exists + await conn.execute( + users_secrets.update() + .where(users_secrets.c.user_id == user_id) + .values(password_hash=password_hash) ) - ) - if value is not None: - assert isinstance(value, str) # nosec - return value - raise UserNotFoundInRepoError + async def is_email_used( + self, connection: AsyncConnection | None = None, *, email: str + ) -> bool: - @staticmethod - async def is_email_used(conn: DBConnection, email: str) -> bool: - email = email.lower() + async with pass_or_acquire_connection(self._engine, connection) as conn: - registered = await conn.scalar( - sa.select(users.c.id).where(users.c.email == email) - ) - if registered: - return True + email = email.lower() - pre_registered = await conn.scalar( - sa.select(users_pre_registration_details.c.user_id).where( - users_pre_registration_details.c.pre_email == email + registered = await conn.scalar( + sa.select(users.c.id).where(users.c.email == email) ) - ) - return bool(pre_registered) + if registered: + return True + + # Check if email exists in pre-registration, regardless of user_id status + pre_registered = await conn.scalar( + sa.select(users_pre_registration_details.c.id).where( + users_pre_registration_details.c.pre_email == email + ) + ) + return bool(pre_registered) + + async def get_billing_details( + self, connection: AsyncConnection | None = None, *, user_id: int + ) -> Any | None: + async with pass_or_acquire_connection(self._engine, connection) as conn: + result = await conn.execute( + sa.select( + users.c.first_name, + users.c.last_name, + users_pre_registration_details.c.institution, + users_pre_registration_details.c.address, + users_pre_registration_details.c.city, + users_pre_registration_details.c.state, + users_pre_registration_details.c.country, + users_pre_registration_details.c.postal_code, + users.c.phone, + ) + .select_from( + users.join( + users_pre_registration_details, + users.c.id == users_pre_registration_details.c.user_id, + ) + ) + .where(users.c.id == user_id) + .order_by(users_pre_registration_details.c.created.desc()) + .limit(1) + # NOTE: might want to copy billing details to users table?? + ) + return result.one_or_none() # diff --git a/packages/postgres-database/tests/conftest.py b/packages/postgres-database/tests/conftest.py index fdac39729b6b..7ba9695aec1f 100644 --- a/packages/postgres-database/tests/conftest.py +++ b/packages/postgres-database/tests/conftest.py @@ -18,11 +18,10 @@ from aiopg.sa.engine import Engine from aiopg.sa.result import ResultProxy, RowProxy from faker import Faker -from pytest_simcore.helpers import postgres_tools +from pytest_simcore.helpers import postgres_tools, postgres_users from pytest_simcore.helpers.faker_factories import ( random_group, random_project, - random_user, ) from simcore_postgres_database.models.products import products from simcore_postgres_database.models.projects import projects @@ -268,10 +267,11 @@ def create_fake_user(sync_engine: sqlalchemy.engine.Engine) -> Iterator[Callable async def _creator( conn: SAConnection, group: RowProxy | None = None, **overrides ) -> RowProxy: - user_id = await conn.scalar( - users.insert().values(**random_user(**overrides)).returning(users.c.id) + + user_id = await postgres_users.insert_user_and_secrets( + conn, + **overrides, ) - assert user_id is not None # This is done in two executions instead of one (e.g. returning(literal_column("*")) ) # to allow triggering function in db that diff --git a/packages/postgres-database/tests/test_models_api_keys.py b/packages/postgres-database/tests/test_models_api_keys.py index d8863f9ac748..d4852d199d6c 100644 --- a/packages/postgres-database/tests/test_models_api_keys.py +++ b/packages/postgres-database/tests/test_models_api_keys.py @@ -9,10 +9,10 @@ import sqlalchemy as sa from aiopg.sa.connection import SAConnection from aiopg.sa.result import RowProxy +from pytest_simcore.helpers import postgres_users from pytest_simcore.helpers.faker_factories import ( random_api_auth, random_product, - random_user, ) from simcore_postgres_database.models.api_keys import api_keys from simcore_postgres_database.models.products import products @@ -21,13 +21,12 @@ @pytest.fixture async def user_id(connection: SAConnection) -> AsyncIterable[int]: - uid = await connection.scalar( - users.insert().values(random_user()).returning(users.c.id) - ) - assert uid - yield uid + user_id = await postgres_users.insert_user_and_secrets(connection) + + assert user_id + yield user_id - await connection.execute(users.delete().where(users.c.id == uid)) + await connection.execute(users.delete().where(users.c.id == user_id)) @pytest.fixture @@ -84,7 +83,10 @@ async def test_get_session_identity_for_api_server( # authorize a session # result = await connection.execute( - sa.select(api_keys.c.user_id, api_keys.c.product_name,).where( + sa.select( + api_keys.c.user_id, + api_keys.c.product_name, + ).where( (api_keys.c.api_key == session_auth.api_key) & (api_keys.c.api_secret == session_auth.api_secret), ) diff --git a/packages/postgres-database/tests/test_models_groups.py b/packages/postgres-database/tests/test_models_groups.py index 6ce8a77c4cc3..a3c5ad154a30 100644 --- a/packages/postgres-database/tests/test_models_groups.py +++ b/packages/postgres-database/tests/test_models_groups.py @@ -10,7 +10,7 @@ from aiopg.sa.connection import SAConnection from aiopg.sa.result import ResultProxy, RowProxy from psycopg2.errors import ForeignKeyViolation, RaiseException, UniqueViolation -from pytest_simcore.helpers.faker_factories import random_user +from pytest_simcore.helpers import postgres_users from simcore_postgres_database.webserver_models import ( GroupType, groups, @@ -64,9 +64,8 @@ async def test_all_group( await connection.execute(groups.delete().where(groups.c.gid == all_group_gid)) # check adding a user is automatically added to the all group - result = await connection.execute( - users.insert().values(**random_user()).returning(literal_column("*")) - ) + user_id = await postgres_users.insert_user_and_secrets(connection) + result = await connection.execute(users.select().where(users.c.id == user_id)) user: RowProxy = await result.fetchone() result = await connection.execute( @@ -98,14 +97,10 @@ async def test_all_group( async def test_own_group( connection: SAConnection, ): - result = await connection.execute( - users.insert().values(**random_user()).returning(literal_column("*")) - ) - user: RowProxy = await result.fetchone() - assert not user.primary_gid + user_id = await postgres_users.insert_user_and_secrets(connection) # now fetch the same user that shall have a primary group set by the db - result = await connection.execute(users.select().where(users.c.id == user.id)) + result = await connection.execute(users.select().where(users.c.id == user_id)) user: RowProxy = await result.fetchone() assert user.primary_gid diff --git a/packages/postgres-database/tests/test_models_projects_to_jobs.py b/packages/postgres-database/tests/test_models_projects_to_jobs.py index d6f2879694d4..e2e5cf0476e2 100644 --- a/packages/postgres-database/tests/test_models_projects_to_jobs.py +++ b/packages/postgres-database/tests/test_models_projects_to_jobs.py @@ -10,12 +10,12 @@ import sqlalchemy as sa import sqlalchemy.engine import sqlalchemy.exc +from common_library.users_enums import UserRole from faker import Faker from pytest_simcore.helpers import postgres_tools from pytest_simcore.helpers.faker_factories import random_project, random_user from simcore_postgres_database.models.projects import projects from simcore_postgres_database.models.projects_to_jobs import projects_to_jobs -from simcore_postgres_database.models.users import users @pytest.fixture @@ -66,9 +66,24 @@ def test_populate_projects_to_jobs_during_migration( # INSERT data (emulates data in-place) user_data = random_user( - faker, name="test_populate_projects_to_jobs_during_migration" + faker, + name="test_populate_projects_to_jobs_during_migration", + role=UserRole.USER.value, ) - stmt = users.insert().values(**user_data).returning(users.c.id) + user_data["password_hash"] = ( + "password_hash_was_still_here_at_this_migration_commit" # noqa: S105 + ) + + columns = list(user_data.keys()) + values_clause = ", ".join(f":{col}" for col in columns) + columns_clause = ", ".join(columns) + stmt = sa.text( + f""" + INSERT INTO users ({columns_clause}) + VALUES ({values_clause}) + RETURNING id + """ # noqa: S608 + ).bindparams(**user_data) result = conn.execute(stmt) user_id = result.scalar() diff --git a/packages/postgres-database/tests/test_users.py b/packages/postgres-database/tests/test_users.py index 8bfe2814ada1..038f9a53fa64 100644 --- a/packages/postgres-database/tests/test_users.py +++ b/packages/postgres-database/tests/test_users.py @@ -3,35 +3,41 @@ # pylint: disable=unused-argument # pylint: disable=unused-variable +from collections.abc import Iterator from datetime import datetime, timedelta import pytest +import simcore_postgres_database.cli import sqlalchemy as sa -from aiopg.sa.connection import SAConnection -from aiopg.sa.result import ResultProxy, RowProxy +import sqlalchemy.engine +import sqlalchemy.exc from faker import Faker +from pytest_simcore.helpers import postgres_tools from pytest_simcore.helpers.faker_factories import random_user -from simcore_postgres_database.aiopg_errors import ( - InvalidTextRepresentation, - UniqueViolation, -) from simcore_postgres_database.models.users import UserRole, UserStatus, users +from simcore_postgres_database.utils_repos import ( + pass_or_acquire_connection, + transaction_context, +) from simcore_postgres_database.utils_users import ( UsersRepo, _generate_username_from_email, generate_alternative_username, ) +from sqlalchemy.exc import DBAPIError, IntegrityError +from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.sql import func @pytest.fixture -async def clean_users_db_table(connection: SAConnection): +async def clean_users_db_table(asyncpg_engine: AsyncEngine): yield - await connection.execute(users.delete()) + async with transaction_context(asyncpg_engine) as connection: + await connection.execute(users.delete()) async def test_user_status_as_pending( - connection: SAConnection, faker: Faker, clean_users_db_table: None + asyncpg_engine: AsyncEngine, faker: Faker, clean_users_db_table: None ): """Checks a bug where the expression @@ -51,10 +57,13 @@ async def test_user_status_as_pending( # tests that the database never stores the word "PENDING" data = random_user(faker, status="PENDING") assert data["status"] == "PENDING" - with pytest.raises(InvalidTextRepresentation) as err_info: - await connection.execute(users.insert().values(data)) + async with transaction_context(asyncpg_engine) as connection: + with pytest.raises(DBAPIError) as err_info: + await connection.execute(users.insert().values(data)) - assert 'invalid input value for enum userstatus: "PENDING"' in f"{err_info.value}" + assert ( + 'invalid input value for enum userstatus: "PENDING"' in f"{err_info.value}" + ) @pytest.mark.parametrize( @@ -66,27 +75,30 @@ async def test_user_status_as_pending( ) async def test_user_status_inserted_as_enum_or_int( status_value: UserStatus | str, - connection: SAConnection, + asyncpg_engine: AsyncEngine, faker: Faker, clean_users_db_table: None, ): # insert as `status_value` data = random_user(faker, status=status_value) assert data["status"] == status_value - user_id = await connection.scalar(users.insert().values(data).returning(users.c.id)) - # get as UserStatus.CONFIRMATION_PENDING - user = await ( - await connection.execute(users.select().where(users.c.id == user_id)) - ).first() - assert user + async with transaction_context(asyncpg_engine) as connection: + user_id = await connection.scalar( + users.insert().values(data).returning(users.c.id) + ) - assert UserStatus(user.status) == UserStatus.CONFIRMATION_PENDING - assert user.status == UserStatus.CONFIRMATION_PENDING + # get as UserStatus.CONFIRMATION_PENDING + result = await connection.execute(users.select().where(users.c.id == user_id)) + user = result.one_or_none() + assert user + + assert UserStatus(user.status) == UserStatus.CONFIRMATION_PENDING + assert user.status == UserStatus.CONFIRMATION_PENDING async def test_unique_username( - connection: SAConnection, faker: Faker, clean_users_db_table: None + asyncpg_engine: AsyncEngine, faker: Faker, clean_users_db_table: None ): data = random_user( faker, @@ -96,33 +108,39 @@ async def test_unique_username( first_name="Pedro", last_name="Crespo Valero", ) - user_id = await connection.scalar(users.insert().values(data).returning(users.c.id)) - user = await ( - await connection.execute(users.select().where(users.c.id == user_id)) - ).first() - assert user - - assert user.id == user_id - assert user.name == "pcrespov" - - # same name fails - data["email"] = faker.email() - with pytest.raises(UniqueViolation): + async with transaction_context(asyncpg_engine) as connection: + user_id = await connection.scalar( + users.insert().values(data).returning(users.c.id) + ) + result = await connection.execute(users.select().where(users.c.id == user_id)) + user = result.one_or_none() + assert user + + assert user.id == user_id + assert user.name == "pcrespov" + + async with transaction_context(asyncpg_engine) as connection: + # same name fails + data["email"] = faker.email() + with pytest.raises(IntegrityError): + await connection.scalar(users.insert().values(data).returning(users.c.id)) + + async with transaction_context(asyncpg_engine) as connection: + # generate new name + data["name"] = _generate_username_from_email(user.email) + data["email"] = faker.email() await connection.scalar(users.insert().values(data).returning(users.c.id)) - # generate new name - data["name"] = _generate_username_from_email(user.email) - data["email"] = faker.email() - await connection.scalar(users.insert().values(data).returning(users.c.id)) + async with transaction_context(asyncpg_engine) as connection: - # and another one - data["name"] = generate_alternative_username(data["name"]) - data["email"] = faker.email() - await connection.scalar(users.insert().values(data).returning(users.c.id)) + # and another one + data["name"] = generate_alternative_username(data["name"]) + data["email"] = faker.email() + await connection.scalar(users.insert().values(data).returning(users.c.id)) async def test_new_user( - connection: SAConnection, faker: Faker, clean_users_db_table: None + asyncpg_engine: AsyncEngine, faker: Faker, clean_users_db_table: None ): data = { "email": faker.email(), @@ -130,7 +148,8 @@ async def test_new_user( "status": UserStatus.ACTIVE, "expires_at": datetime.utcnow(), } - new_user = await UsersRepo.new_user(connection, **data) + repo = UsersRepo(asyncpg_engine) + new_user = await repo.new_user(**data) assert new_user.email == data["email"] assert new_user.status == data["status"] @@ -140,51 +159,205 @@ async def test_new_user( assert _generate_username_from_email(other_email) == new_user.name other_data = {**data, "email": other_email} - other_user = await UsersRepo.new_user(connection, **other_data) + other_user = await repo.new_user(**other_data) assert other_user.email != new_user.email assert other_user.name != new_user.name - assert await UsersRepo.get_email(connection, other_user.id) == other_user.email - assert await UsersRepo.get_role(connection, other_user.id) == other_user.role - assert ( - await UsersRepo.get_active_user_email(connection, other_user.id) - == other_user.email - ) + async with pass_or_acquire_connection(asyncpg_engine) as connection: + assert ( + await repo.get_email(connection, user_id=other_user.id) == other_user.email + ) + assert await repo.get_role(connection, user_id=other_user.id) == other_user.role + assert ( + await repo.get_active_user_email(connection, user_id=other_user.id) + == other_user.email + ) -async def test_trial_accounts(connection: SAConnection, clean_users_db_table: None): +async def test_trial_accounts(asyncpg_engine: AsyncEngine, clean_users_db_table: None): EXPIRATION_INTERVAL = timedelta(minutes=5) # creates trial user client_now = datetime.utcnow() - user_id: int | None = await connection.scalar( - users.insert() - .values( - **random_user( - status=UserStatus.ACTIVE, - # Using some magic from sqlachemy ... - expires_at=func.now() + EXPIRATION_INTERVAL, + async with transaction_context(asyncpg_engine) as connection: + user_id: int | None = await connection.scalar( + users.insert() + .values( + **random_user( + status=UserStatus.ACTIVE, + # Using some magic from sqlachemy ... + expires_at=func.now() + EXPIRATION_INTERVAL, + ) ) + .returning(users.c.id) ) - .returning(users.c.id) - ) - assert user_id + assert user_id - # check expiration date - result: ResultProxy = await connection.execute( - sa.select(users.c.status, users.c.created_at, users.c.expires_at).where( - users.c.id == user_id + # check expiration date + result = await connection.execute( + sa.select(users.c.status, users.c.created_at, users.c.expires_at).where( + users.c.id == user_id + ) ) + row = result.one_or_none() + assert row + assert row.created_at - client_now < timedelta( + minutes=1 + ), "Difference between server and client now should not differ much" + assert row.expires_at - row.created_at == EXPIRATION_INTERVAL + assert row.status == UserStatus.ACTIVE + + # sets user as expired + await connection.execute( + users.update() + .values(status=UserStatus.EXPIRED) + .where(users.c.id == user_id) + ) + + +@pytest.fixture +def sync_engine_with_migration( + sync_engine: sqlalchemy.engine.Engine, db_metadata: sa.MetaData +) -> Iterator[sqlalchemy.engine.Engine]: + # EXTENDS sync_engine fixture to include cleanup and prepare migration + + # cleanup tables + db_metadata.drop_all(sync_engine) + + # prepare migration upgrade + assert simcore_postgres_database.cli.discover.callback + assert simcore_postgres_database.cli.upgrade.callback + + dsn = sync_engine.url + simcore_postgres_database.cli.discover.callback( + user=dsn.username, + password=dsn.password, + host=dsn.host, + database=dsn.database, + port=dsn.port, ) - row: RowProxy | None = await result.first() - assert row - assert row.created_at - client_now < timedelta( - minutes=1 - ), "Difference between server and client now should not differ much" - assert row.expires_at - row.created_at == EXPIRATION_INTERVAL - assert row.status == UserStatus.ACTIVE - - # sets user as expired - await connection.execute( - users.update().values(status=UserStatus.EXPIRED).where(users.c.id == user_id) - ) + + yield sync_engine + + # cleanup tables + postgres_tools.force_drop_all_tables(sync_engine) + + +def test_users_secrets_migration_upgrade_downgrade( + sync_engine_with_migration: sqlalchemy.engine.Engine, faker: Faker +): + """Tests the migration script that moves password_hash from users to users_secrets table. + + + testing + packages/postgres-database/src/simcore_postgres_database/migration/versions/5679165336c8_new_users_secrets.py + + revision = "5679165336c8" + down_revision = "61b98a60e934" + + + NOTE: all statements in conn.execute(...) must be sa.text(...) since at that migration point the schemas of the + code models might not be the same + """ + assert simcore_postgres_database.cli.discover.callback + assert simcore_postgres_database.cli.upgrade.callback + assert simcore_postgres_database.cli.downgrade.callback + + # UPGRADE just one before 5679165336c8_new_users_secrets.py + simcore_postgres_database.cli.upgrade.callback("61b98a60e934") + + with sync_engine_with_migration.connect() as conn: + # Ensure the users_secrets table does NOT exist yet + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + conn.execute( + sa.select(sa.func.count()).select_from(sa.table("users_secrets")) + ).scalar() + assert "psycopg2.errors.UndefinedTable" in f"{exc_info.value}" + + # INSERT users with password hashes (emulates data in-place before migration) + users_data_with_hashed_password = [ + { + **random_user( + faker, + name="user_with_password_1", + email="user1@example.com", + role=UserRole.USER.value, + status=UserStatus.ACTIVE, + ), + "password_hash": "hashed_password_1", # noqa: S106 + }, + { + **random_user( + faker, + name="user_with_password_2", + email="user2@example.com", + role=UserRole.USER.value, + status=UserStatus.ACTIVE, + ), + "password_hash": "hashed_password_2", # noqa: S106 + }, + ] + + inserted_user_ids = [] + for user_data in users_data_with_hashed_password: + columns = ", ".join(user_data.keys()) + values_placeholders = ", ".join(f":{key}" for key in user_data) + user_id = conn.execute( + sa.text( + f"INSERT INTO users ({columns}) VALUES ({values_placeholders}) RETURNING id" # noqa: S608 + ), + user_data, + ).scalar() + inserted_user_ids.append(user_id) + + # Verify password hashes are in users table + result = conn.execute( + sa.text("SELECT id, password_hash FROM users WHERE id = ANY(:user_ids)"), + {"user_ids": inserted_user_ids}, + ).fetchall() + + password_hashes_before = {row.id: row.password_hash for row in result} + assert len(password_hashes_before) == 2 + assert password_hashes_before[inserted_user_ids[0]] == "hashed_password_1" + assert password_hashes_before[inserted_user_ids[1]] == "hashed_password_2" + + # MIGRATE UPGRADE: this should move password hashes to users_secrets + # packages/postgres-database/src/simcore_postgres_database/migration/versions/5679165336c8_new_users_secrets.py + simcore_postgres_database.cli.upgrade.callback("5679165336c8") + + with sync_engine_with_migration.connect() as conn: + # Verify users_secrets table exists and contains the password hashes + result = conn.execute( + sa.text("SELECT user_id, password_hash FROM users_secrets ORDER BY user_id") + ).fetchall() + + # Only users with non-null password hashes should be in users_secrets + assert len(result) == 2 + secrets_data = {row.user_id: row.password_hash for row in result} + assert secrets_data[inserted_user_ids[0]] == "hashed_password_1" + assert secrets_data[inserted_user_ids[1]] == "hashed_password_2" + + # Verify password_hash column is removed from users table + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + conn.execute(sa.text("SELECT password_hash FROM users")) + assert "psycopg2.errors.UndefinedColumn" in f"{exc_info.value}" + + # MIGRATE DOWNGRADE: this should move password hashes back to users + simcore_postgres_database.cli.downgrade.callback("61b98a60e934") + + with sync_engine_with_migration.connect() as conn: + # Verify users_secrets table no longer exists + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + conn.execute(sa.text("SELECT COUNT(*) FROM users_secrets")).scalar() + assert "psycopg2.errors.UndefinedTable" in f"{exc_info.value}" + + # Verify password hashes are back in users table + result = conn.execute( + sa.text("SELECT id, password_hash FROM users WHERE id = ANY(:user_ids)"), + {"user_ids": inserted_user_ids}, + ).fetchall() + + password_hashes_after = {row.id: row.password_hash for row in result} + assert len(password_hashes_after) == 2 + assert password_hashes_after[inserted_user_ids[0]] == "hashed_password_1" + assert password_hashes_after[inserted_user_ids[1]] == "hashed_password_2" diff --git a/packages/postgres-database/tests/test_users_details.py b/packages/postgres-database/tests/test_users_details.py index e4b6bfeb70fc..077d9774a3cf 100644 --- a/packages/postgres-database/tests/test_users_details.py +++ b/packages/postgres-database/tests/test_users_details.py @@ -257,15 +257,18 @@ async def test_create_and_link_user_from_pre_registration( # Invitation link is clicked and the user is created and linked to the pre-registration async with transaction_context(asyncpg_engine) as connection: # user gets created - new_user = await UsersRepo.new_user( + repo = UsersRepo(asyncpg_engine) + new_user = await repo.new_user( connection, email=pre_email, password_hash="123456", # noqa: S106 status=UserStatus.ACTIVE, expires_at=None, ) - await UsersRepo.link_and_update_user_from_pre_registration( - connection, new_user_id=new_user.id, new_user_email=new_user.email + await repo.link_and_update_user_from_pre_registration( + connection, + new_user_id=new_user.id, + new_user_email=new_user.email, ) # Verify the user was created and linked @@ -291,23 +294,23 @@ async def test_get_billing_details_from_pre_registration( # Create the user async with transaction_context(asyncpg_engine) as connection: - new_user = await UsersRepo.new_user( + repo = UsersRepo(asyncpg_engine) + new_user = await repo.new_user( connection, email=pre_email, password_hash="123456", # noqa: S106 status=UserStatus.ACTIVE, expires_at=None, ) - await UsersRepo.link_and_update_user_from_pre_registration( - connection, new_user_id=new_user.id, new_user_email=new_user.email + await repo.link_and_update_user_from_pre_registration( + connection, + new_user_id=new_user.id, + new_user_email=new_user.email, ) # Get billing details - async with pass_or_acquire_connection(asyncpg_engine) as connection: - invoice_data = await UsersRepo.get_billing_details( - connection, user_id=new_user.id - ) - assert invoice_data is not None + invoice_data = await repo.get_billing_details(user_id=new_user.id) + assert invoice_data is not None # Test UserAddress model conversion user_address = UserAddress.create_from_db(invoice_data) @@ -331,15 +334,18 @@ async def test_update_user_from_pre_registration( # Create the user and link to pre-registration async with transaction_context(asyncpg_engine) as connection: - new_user = await UsersRepo.new_user( + repo = UsersRepo(asyncpg_engine) + new_user = await repo.new_user( connection, email=pre_email, password_hash="123456", # noqa: S106 status=UserStatus.ACTIVE, expires_at=None, ) - await UsersRepo.link_and_update_user_from_pre_registration( - connection, new_user_id=new_user.id, new_user_email=new_user.email + await repo.link_and_update_user_from_pre_registration( + connection, + new_user_id=new_user.id, + new_user_email=new_user.email, ) # Update the user manually @@ -358,8 +364,11 @@ async def test_update_user_from_pre_registration( # Re-link the user to pre-registration, which should override manual updates async with transaction_context(asyncpg_engine) as connection: - await UsersRepo.link_and_update_user_from_pre_registration( - connection, new_user_id=new_user.id, new_user_email=new_user.email + repo = UsersRepo(asyncpg_engine) + await repo.link_and_update_user_from_pre_registration( + connection, + new_user_id=new_user.id, + new_user_email=new_user.email, ) result = await connection.execute( @@ -487,20 +496,24 @@ async def test_user_preregisters_for_multiple_products_with_different_outcomes( assert registrations[1].account_request_reviewed_by == product_owner_user["id"] assert registrations[1].account_request_reviewed_at is not None - # 3.Now create a user account with the approved pre-registration + # 3. Now create a user account and link ALL pre-registrations for this email async with transaction_context(asyncpg_engine) as connection: - new_user = await UsersRepo.new_user( + repo = UsersRepo(asyncpg_engine) + new_user = await repo.new_user( connection, email=user_email, password_hash="123456", # noqa: S106 status=UserStatus.ACTIVE, expires_at=None, ) - await UsersRepo.link_and_update_user_from_pre_registration( - connection, new_user_id=new_user.id, new_user_email=new_user.email + # Link all pre-registrations for this email, regardless of approval status or product + await repo.link_and_update_user_from_pre_registration( + connection, + new_user_id=new_user.id, + new_user_email=new_user.email, ) - # Verify both pre-registrations are linked to the new user + # Verify ALL pre-registrations for this email are linked to the user async with pass_or_acquire_connection(asyncpg_engine) as connection: result = await connection.execute( sa.select( @@ -515,5 +528,17 @@ async def test_user_preregisters_for_multiple_products_with_different_outcomes( registrations = result.fetchall() assert len(registrations) == 2 - # Both registrations should be linked to the same user, regardless of approval status - assert all(reg.user_id == new_user.id for reg in registrations) + # Both pre-registrations should be linked to the user, regardless of approval status + product1_reg = next( + reg for reg in registrations if reg.product_name == product1["name"] + ) + product2_reg = next( + reg for reg in registrations if reg.product_name == product2["name"] + ) + + assert product1_reg.user_id == new_user.id # Linked + assert product2_reg.user_id == new_user.id # Linked + + # Verify approval status is preserved independently of linking + assert product1_reg.account_request_status == AccountRequestStatus.APPROVED + assert product2_reg.account_request_status == AccountRequestStatus.REJECTED diff --git a/packages/postgres-database/tests/test_utils_users.py b/packages/postgres-database/tests/test_utils_users.py index d4a7039f1f3e..0f61ba27ed9d 100644 --- a/packages/postgres-database/tests/test_utils_users.py +++ b/packages/postgres-database/tests/test_utils_users.py @@ -8,14 +8,13 @@ from typing import Any import pytest +import sqlalchemy as sa from faker import Faker -from pytest_simcore.helpers.faker_factories import ( - random_user, +from pytest_simcore.helpers.postgres_users import ( + insert_and_get_user_and_secrets_lifespan, ) -from pytest_simcore.helpers.postgres_tools import ( - insert_and_get_row_lifespan, -) -from simcore_postgres_database.models.users import UserRole, users +from simcore_postgres_database.models.users import UserRole +from simcore_postgres_database.models.users_secrets import users_secrets from simcore_postgres_database.utils_repos import ( pass_or_acquire_connection, ) @@ -28,24 +27,71 @@ async def user( faker: Faker, asyncpg_engine: AsyncEngine, ) -> AsyncIterable[dict[str, Any]]: - async with insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + async with insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup asyncpg_engine, - table=users, - values=random_user( - faker, - role=faker.random_element(elements=UserRole), - ), - pk_col=users.c.id, - ) as row: - yield row + role=faker.random_element(elements=UserRole), + ) as user_and_secrets_row: + yield user_and_secrets_row async def test_users_repo_get(asyncpg_engine: AsyncEngine, user: dict[str, Any]): - repo = UsersRepo() + repo = UsersRepo(asyncpg_engine) async with pass_or_acquire_connection(asyncpg_engine) as connection: assert await repo.get_email(connection, user_id=user["id"]) == user["email"] assert await repo.get_role(connection, user_id=user["id"]) == user["role"] + assert ( + await repo.get_password_hash(connection, user_id=user["id"]) + == user["password_hash"] + ) + assert ( + await repo.get_active_user_email(connection, user_id=user["id"]) + == user["email"] + ) with pytest.raises(UserNotFoundInRepoError): await repo.get_role(connection, user_id=55) + with pytest.raises(UserNotFoundInRepoError): + await repo.get_email(connection, user_id=55) + with pytest.raises(UserNotFoundInRepoError): + await repo.get_password_hash(connection, user_id=55) + with pytest.raises(UserNotFoundInRepoError): + await repo.get_active_user_email(connection, user_id=55) + + +async def test_update_user_password_hash_updates_modified_column( + asyncpg_engine: AsyncEngine, user: dict[str, Any], faker: Faker +): + repo = UsersRepo(asyncpg_engine) + + async with pass_or_acquire_connection(asyncpg_engine) as connection: + # Get initial modified timestamp + result = await connection.execute( + sa.select(users_secrets.c.modified).where( + users_secrets.c.user_id == user["id"] + ) + ) + initial_modified = result.scalar_one() + + # Update password hash + new_password_hash = faker.password() + await repo.update_user_password_hash( + connection, user_id=user["id"], password_hash=new_password_hash + ) + + # Get updated modified timestamp + result = await connection.execute( + sa.select(users_secrets.c.modified).where( + users_secrets.c.user_id == user["id"] + ) + ) + updated_modified = result.scalar_one() + + # Verify modified timestamp changed + assert updated_modified > initial_modified + + # Verify password hash was actually updated + assert ( + await repo.get_password_hash(connection, user_id=user["id"]) + == new_password_hash + ) diff --git a/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py b/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py index 737ba8e89520..8a13ecae3a4d 100644 --- a/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py +++ b/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py @@ -3,6 +3,7 @@ # pylint:disable=redefined-outer-name # pylint:disable=no-value-for-parameter +import contextlib from collections.abc import AsyncIterator, Awaitable, Callable, Iterator from typing import Any from uuid import uuid4 @@ -19,7 +20,7 @@ from simcore_postgres_database.models.projects import ProjectType, projects from simcore_postgres_database.models.projects_to_products import projects_to_products from simcore_postgres_database.models.services import services_access_rights -from simcore_postgres_database.models.users import UserRole, UserStatus, users +from simcore_postgres_database.models.users import UserRole, UserStatus from simcore_postgres_database.utils_projects_nodes import ( ProjectNodeCreate, ProjectNodesRepo, @@ -27,44 +28,38 @@ from sqlalchemy.ext.asyncio import AsyncEngine from .helpers.postgres_tools import insert_and_get_row_lifespan +from .helpers.postgres_users import sync_insert_and_get_user_and_secrets_lifespan @pytest.fixture() def create_registered_user( - postgres_db: sa.engine.Engine, faker: Faker + postgres_db: sa.engine.Engine, ) -> Iterator[Callable[..., dict]]: + """Fixture to create a registered user with secrets in the database.""" created_user_ids = [] - def _(**user_kwargs) -> dict[str, Any]: - with postgres_db.connect() as con: - # removes all users before continuing - user_config = { - "id": len(created_user_ids) + 1, - "name": faker.name(), - "email": faker.email(), - "password_hash": faker.password(), - "status": UserStatus.ACTIVE, - "role": UserRole.USER, - } - user_config.update(user_kwargs) + with contextlib.ExitStack() as stack: - con.execute( - users.insert().values(user_config).returning(sa.literal_column("*")) - ) - # this is needed to get the primary_gid correctly - result = con.execute( - sa.select(users).where(users.c.id == user_config["id"]) + def _(**user_kwargs) -> dict[str, Any]: + + user_id = len(created_user_ids) + 1 + user = stack.enter_context( + sync_insert_and_get_user_and_secrets_lifespan( + postgres_db, + status=UserStatus.ACTIVE, + role=UserRole.USER, + id=user_id, + **user_kwargs, + ) ) - user = result.first() - assert user + print(f"--> created {user=}") + assert user["id"] == user_id created_user_ids.append(user["id"]) - return dict(user._asdict()) + return user - yield _ + yield _ - with postgres_db.connect() as con: - con.execute(users.delete().where(users.c.id.in_(created_user_ids))) print(f"<-- deleted users {created_user_ids=}") diff --git a/packages/pytest-simcore/src/pytest_simcore/faker_users_data.py b/packages/pytest-simcore/src/pytest_simcore/faker_users_data.py index 4e59b6db93a4..070087982e7d 100644 --- a/packages/pytest-simcore/src/pytest_simcore/faker_users_data.py +++ b/packages/pytest-simcore/src/pytest_simcore/faker_users_data.py @@ -3,9 +3,9 @@ # pylint: disable=unused-variable # pylint: disable=too-many-arguments """ - Fixtures to produce fake data for a user: - - it is self-consistent - - granular customization by overriding fixtures +Fixtures to produce fake data for a user: + - it is self-consistent + - granular customization by overriding fixtures """ from typing import Any @@ -16,7 +16,11 @@ from models_library.users import UserID from pydantic import EmailStr, TypeAdapter -from .helpers.faker_factories import DEFAULT_TEST_PASSWORD, random_user +from .helpers.faker_factories import ( + DEFAULT_TEST_PASSWORD, + random_user, + random_user_secrets, +) _MESSAGE = ( "If set, it overrides the fake value of `{}` fixture." @@ -125,12 +129,17 @@ def user( user_name: IDStr, user_password: str, ) -> dict[str, Any]: - return random_user( - id=user_id, - email=user_email, - name=user_name, - first_name=user_first_name, - last_name=user_last_name, - password=user_password, - fake=faker, - ) + """NOTE: it returns user data including poassword and password_hash""" + secrets = random_user_secrets(fake=faker, user_id=user_id, password=user_password) + assert secrets["user_id"] == user_id + return { + **random_user( + id=user_id, + email=user_email, + name=user_name, + first_name=user_first_name, + last_name=user_last_name, + fake=faker, + ), + "password_hash": secrets["password_hash"], + } diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/faker_factories.py b/packages/pytest-simcore/src/pytest_simcore/helpers/faker_factories.py index 5a9b1a5a5d1b..4b09b0ef06b5 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/faker_factories.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/faker_factories.py @@ -63,9 +63,7 @@ def _compute_hash(password: str) -> str: _DEFAULT_HASH = _compute_hash(DEFAULT_TEST_PASSWORD) -def random_user( - fake: Faker = DEFAULT_FAKER, password: str | None = None, **overrides -) -> dict[str, Any]: +def random_user(fake: Faker = DEFAULT_FAKER, **overrides) -> dict[str, Any]: from simcore_postgres_database.models.users import users from simcore_postgres_database.webserver_models import UserStatus @@ -75,12 +73,35 @@ def random_user( # NOTE: ensures user name is unique to avoid flaky tests "name": f"{fake.user_name()}_{fake.uuid4()}", "email": f"{fake.uuid4()}_{fake.email().lower()}", - "password_hash": _DEFAULT_HASH, "status": UserStatus.ACTIVE, } + data.update(overrides) assert set(data.keys()).issubset({c.name for c in users.columns}) + return data + + +def random_user_secrets( + fake: Faker = DEFAULT_FAKER, + *, + # foreign keys + user_id: int, + password: str | None = None, + **overrides, +) -> dict[str, Any]: + from simcore_postgres_database.models.users_secrets import users_secrets + + assert fake # nosec + + assert set(overrides.keys()).issubset({c.name for c in users_secrets.columns}) + + data = { + "user_id": user_id, + "password_hash": _DEFAULT_HASH, + } + assert set(data.keys()).issubset({c.name for c in users_secrets.columns}) + # transform password in hash if password: assert len(password) >= 12 diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/postgres_users.py b/packages/pytest-simcore/src/pytest_simcore/helpers/postgres_users.py new file mode 100644 index 000000000000..dd4039619cd3 --- /dev/null +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/postgres_users.py @@ -0,0 +1,106 @@ +import contextlib + +import sqlalchemy as sa +from simcore_postgres_database.models.users import users +from simcore_postgres_database.models.users_secrets import users_secrets +from sqlalchemy.ext.asyncio import AsyncEngine + +from .faker_factories import random_user, random_user_secrets +from .postgres_tools import ( + insert_and_get_row_lifespan, + sync_insert_and_get_row_lifespan, +) + + +def _get_kwargs_from_overrides(overrides: dict) -> tuple[dict, dict]: + user_kwargs = overrides.copy() + secrets_kwargs = {"password": user_kwargs.pop("password", None)} + if "password_hash" in user_kwargs: + secrets_kwargs["password_hash"] = user_kwargs.pop("password_hash") + return user_kwargs, secrets_kwargs + + +@contextlib.asynccontextmanager +async def insert_and_get_user_and_secrets_lifespan( + sqlalchemy_async_engine: AsyncEngine, **overrides +): + user_kwargs, secrets_kwargs = _get_kwargs_from_overrides(overrides) + + async with contextlib.AsyncExitStack() as stack: + # users + user = await stack.enter_async_context( + insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + sqlalchemy_async_engine, + table=users, + values=random_user(**user_kwargs), + pk_col=users.c.id, + ) + ) + + # users_secrets + secrets = await stack.enter_async_context( + insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + sqlalchemy_async_engine, + table=users_secrets, + values=random_user_secrets(user_id=user["id"], **secrets_kwargs), + pk_col=users_secrets.c.user_id, + ) + ) + + assert secrets.pop("user_id", None) == user["id"] + + yield {**user, **secrets} + + +@contextlib.contextmanager +def sync_insert_and_get_user_and_secrets_lifespan( + sqlalchemy_sync_engine: sa.engine.Engine, **overrides +): + user_kwargs, secrets_kwargs = _get_kwargs_from_overrides(overrides) + + with contextlib.ExitStack() as stack: + # users + user = stack.enter_context( + sync_insert_and_get_row_lifespan( + sqlalchemy_sync_engine, + table=users, + values=random_user(**user_kwargs), + pk_col=users.c.id, + ) + ) + + # users_secrets + secrets = stack.enter_context( + sync_insert_and_get_row_lifespan( + sqlalchemy_sync_engine, + table=users_secrets, + values=random_user_secrets(user_id=user["id"], **secrets_kwargs), + pk_col=users_secrets.c.user_id, + ) + ) + + assert secrets.pop("user_id", None) == user["id"] + + yield {**user, **secrets} + + +async def insert_user_and_secrets(conn, **overrides) -> int: + # NOTE: DEPRECATED: Legacy adapter. Use insert_and_get_user_and_secrets_lifespan instead + # Temporarily used where conn is produce by aiopg_engine + + user_kwargs, secrets_kwargs = _get_kwargs_from_overrides(overrides) + + # user data + user_id = await conn.scalar( + users.insert().values(**random_user(**user_kwargs)).returning(users.c.id) + ) + assert user_id is not None + + # secrets + await conn.execute( + users_secrets.insert().values( + **random_user_secrets(user_id=user_id, **secrets_kwargs) + ) + ) + + return user_id diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/webserver_users.py b/packages/pytest-simcore/src/pytest_simcore/helpers/webserver_users.py index 1065df8aecf7..edb3399a14fa 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/webserver_users.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/webserver_users.py @@ -3,16 +3,15 @@ from typing import Any, TypedDict from aiohttp import web +from common_library.users_enums import UserRole, UserStatus from models_library.users import UserID -from simcore_postgres_database.models.users import users as users_table -from simcore_service_webserver.db.models import UserRole, UserStatus from simcore_service_webserver.db.plugin import get_asyncpg_engine from simcore_service_webserver.groups import api as groups_service from simcore_service_webserver.products.products_service import list_products from sqlalchemy.ext.asyncio import AsyncEngine -from .faker_factories import DEFAULT_TEST_PASSWORD, random_user -from .postgres_tools import insert_and_get_row_lifespan +from .faker_factories import DEFAULT_TEST_PASSWORD +from .postgres_users import insert_and_get_user_and_secrets_lifespan # WARNING: DO NOT use UserDict is already in https://docs.python.org/3/library/collections.html#collections.UserDictclass UserRowDict(TypedDict): @@ -51,11 +50,8 @@ async def _create_user_in_db( # inject in db user = await exit_stack.enter_async_context( - insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup - sqlalchemy_async_engine, - table=users_table, - values=random_user(**data), - pk_col=users_table.c.id, + insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + sqlalchemy_async_engine, **data ) ) assert "first_name" in user diff --git a/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py b/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py index e897b9ced75e..a41d4876612d 100644 --- a/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py +++ b/packages/pytest-simcore/src/pytest_simcore/simcore_storage_data_models.py @@ -18,7 +18,8 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine -from .helpers.faker_factories import DEFAULT_FAKER, random_project, random_user +from .helpers.faker_factories import DEFAULT_FAKER, random_project +from .helpers.postgres_users import insert_and_get_user_and_secrets_lifespan @asynccontextmanager @@ -30,19 +31,10 @@ async def _user_context( # NOTE: Ideally this (and next fixture) should be done via webserver API but at this point # in time, the webserver service would bring more dependencies to other services # which would turn this test too complex. - - # pylint: disable=no-value-for-parameter - stmt = users.insert().values(**random_user(name=name)).returning(users.c.id) - async with sqlalchemy_async_engine.begin() as conn: - result = await conn.execute(stmt) - row = result.one() - assert isinstance(row.id, int) - - try: - yield TypeAdapter(UserID).validate_python(row.id) - finally: - async with sqlalchemy_async_engine.begin() as conn: - await conn.execute(users.delete().where(users.c.id == row.id)) + async with insert_and_get_user_and_secrets_lifespan( + sqlalchemy_async_engine, name=name + ) as user: + yield TypeAdapter(UserID).validate_python(user["id"]) @pytest.fixture diff --git a/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py b/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py index 58616a3d3359..6bad10e73711 100644 --- a/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py +++ b/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py @@ -105,7 +105,7 @@ def _handle_unexpected_exception_as_500( return http_error -def _handle_http_error( +def handle_aiohttp_web_http_error( request: web.BaseRequest, exception: web.HTTPError ) -> web.HTTPError: """Handle standard HTTP errors by ensuring they're properly formatted. @@ -156,7 +156,7 @@ def _handle_http_error( return exception -def _handle_http_successful( +def _handle_aiohttp_web_http_successful( request: web.Request, exception: web.HTTPSuccessful ) -> web.HTTPSuccessful: """Handle successful HTTP responses, ensuring they're properly enveloped.""" @@ -217,10 +217,10 @@ async def _middleware_handler(request: web.Request, handler: Handler): result = await handler(request) except web.HTTPError as exc: # 4XX and 5XX raised as exceptions - result = _handle_http_error(request, exc) + result = handle_aiohttp_web_http_error(request, exc) except web.HTTPSuccessful as exc: # 2XX rased as exceptions - result = _handle_http_successful(request, exc) + result = _handle_aiohttp_web_http_successful(request, exc) except web.HTTPRedirection as exc: # 3XX raised as exceptions result = exc diff --git a/packages/service-library/tests/aiohttp/test_rest_middlewares.py b/packages/service-library/tests/aiohttp/test_rest_middlewares.py index d87415858083..26884dbc11cd 100644 --- a/packages/service-library/tests/aiohttp/test_rest_middlewares.py +++ b/packages/service-library/tests/aiohttp/test_rest_middlewares.py @@ -361,13 +361,13 @@ async def test_exception_in_handler_returns_500( ): """Test that exceptions in the handler functions are caught and return 500.""" - # Mock _handle_http_successful to raise an exception + # Mock _handle_aiohttp_web_http_successful to raise an exception def mocked_handler(*args, **kwargs): msg = "Simulated error in handler" raise ValueError(msg) mocker.patch( - "servicelib.aiohttp.rest_middlewares._handle_http_successful", + "servicelib.aiohttp.rest_middlewares._handle_aiohttp_web_http_successful", side_effect=mocked_handler, ) diff --git a/packages/simcore-sdk/tests/integration/conftest.py b/packages/simcore-sdk/tests/integration/conftest.py index 875e322eeb10..c7c755c24d5e 100644 --- a/packages/simcore-sdk/tests/integration/conftest.py +++ b/packages/simcore-sdk/tests/integration/conftest.py @@ -19,6 +19,7 @@ from models_library.users import UserID from pydantic import TypeAdapter from pytest_simcore.helpers.faker_factories import random_project, random_user +from pytest_simcore.helpers.postgres_tools import sync_insert_and_get_row_lifespan from settings_library.aws_s3_cli import AwsS3CliSettings from settings_library.r_clone import RCloneSettings, S3Provider from settings_library.s3 import S3Settings @@ -40,18 +41,16 @@ def user_id(postgres_db: sa.engine.Engine) -> Iterable[UserID]: # which would turn this test too complex. # pylint: disable=no-value-for-parameter - with postgres_db.connect() as conn: - result = conn.execute( - users.insert().values(**random_user(name="test")).returning(users.c.id) - ) - row = result.first() - assert row - usr_id = row[users.c.id] - - yield usr_id - - with postgres_db.connect() as conn: - conn.execute(users.delete().where(users.c.id == usr_id)) + with sync_insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + postgres_db, + table=users, + values=random_user( + name="test", + ), + pk_col=users.c.id, + ) as user_row: + + yield user_row["id"] @pytest.fixture diff --git a/services/catalog/tests/unit/with_dbs/conftest.py b/services/catalog/tests/unit/with_dbs/conftest.py index 15bc8ac5b74b..a37aaf479302 100644 --- a/services/catalog/tests/unit/with_dbs/conftest.py +++ b/services/catalog/tests/unit/with_dbs/conftest.py @@ -22,13 +22,15 @@ from pytest_simcore.helpers.faker_factories import ( random_service_access_rights, random_service_meta_data, - random_user, ) from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.postgres_tools import ( PostgresTestConfig, insert_and_get_row_lifespan, ) +from pytest_simcore.helpers.postgres_users import ( + insert_and_get_user_and_secrets_lifespan, +) from pytest_simcore.helpers.typing_env import EnvVarsDict from simcore_postgres_database.models.groups import groups from simcore_postgres_database.models.products import products @@ -36,7 +38,6 @@ services_access_rights, services_meta_data, ) -from simcore_postgres_database.models.users import users from simcore_service_catalog.core.settings import ApplicationSettings from sqlalchemy import sql from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -151,12 +152,9 @@ async def user( injects a user in db """ assert user_id == user["id"] - async with insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + async with insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup sqlalchemy_async_engine, - table=users, - values=user, - pk_col=users.c.id, - pk_value=user["id"], + **user, ) as row: yield row @@ -165,16 +163,14 @@ async def user( async def other_user( user_id: UserID, sqlalchemy_async_engine: AsyncEngine, - faker: Faker, ) -> AsyncIterator[dict[str, Any]]: - - _other_user = random_user(fake=faker, id=user_id + 1) - async with insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + """ + injects a other user in db (!= user) + """ + async with insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup sqlalchemy_async_engine, - table=users, - values=_other_user, - pk_col=users.c.id, - pk_value=_other_user["id"], + name="other_user", + id=user_id + 1, ) as row: yield row diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/users.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/users.py index 80118e2f1b6a..647f8bd6ccc6 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/users.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/users.py @@ -1,3 +1,5 @@ +from typing import cast + from models_library.users import UserID from pydantic import EmailStr, TypeAdapter from simcore_postgres_database.models.users import UserRole @@ -7,11 +9,12 @@ class UsersRepository(BaseRepository): + def _repo(self): + return UsersRepo(self.db_engine) + async def get_user_email(self, user_id: UserID) -> EmailStr: - async with self.db_engine.connect() as conn: - email = await UsersRepo.get_email(conn, user_id) - return TypeAdapter(EmailStr).validate_python(email) + email = await self._repo().get_email(user_id=user_id) + return TypeAdapter(EmailStr).validate_python(email) async def get_user_role(self, user_id: UserID) -> UserRole: - async with self.db_engine.connect() as conn: - return await UsersRepo().get_role(conn, user_id=user_id) + return cast(UserRole, await self._repo().get_role(user_id=user_id)) diff --git a/services/dynamic-scheduler/tests/unit/test_repository_postgres_networks.py b/services/dynamic-scheduler/tests/unit/test_repository_postgres_networks.py index e0374fb31dc4..9ed34d603d42 100644 --- a/services/dynamic-scheduler/tests/unit/test_repository_postgres_networks.py +++ b/services/dynamic-scheduler/tests/unit/test_repository_postgres_networks.py @@ -17,9 +17,11 @@ PostgresTestConfig, insert_and_get_row_lifespan, ) +from pytest_simcore.helpers.postgres_users import ( + insert_and_get_user_and_secrets_lifespan, +) from pytest_simcore.helpers.typing_env import EnvVarsDict from simcore_postgres_database.models.projects import projects -from simcore_postgres_database.models.users import users from simcore_service_dynamic_scheduler.repository.events import ( get_project_networks_repo, ) @@ -77,17 +79,14 @@ async def user_in_db( user_id: UserID, ) -> AsyncIterator[dict[str, Any]]: """ - injects a user in db + injects a user + secrets in db """ assert user_id == user["id"] - async with insert_and_get_row_lifespan( + async with insert_and_get_user_and_secrets_lifespan( engine, - table=users, - values=user, - pk_col=users.c.id, - pk_value=user["id"], - ) as row: - yield row + **user, + ) as user_row: + yield user_row @pytest.fixture diff --git a/services/dynamic-sidecar/tests/integration/conftest.py b/services/dynamic-sidecar/tests/integration/conftest.py index 98ba076f2604..5972315d910d 100644 --- a/services/dynamic-sidecar/tests/integration/conftest.py +++ b/services/dynamic-sidecar/tests/integration/conftest.py @@ -4,6 +4,7 @@ import sqlalchemy as sa from models_library.users import UserID from pytest_simcore.helpers.faker_factories import random_user +from pytest_simcore.helpers.postgres_tools import sync_insert_and_get_row_lifespan from simcore_postgres_database.models.users import users pytest_plugins = [ @@ -24,15 +25,13 @@ def user_id(postgres_db: sa.engine.Engine) -> Iterable[UserID]: # which would turn this test too complex. # pylint: disable=no-value-for-parameter - stmt = users.insert().values(**random_user(name="test")).returning(users.c.id) - print(f"{stmt}") - with postgres_db.connect() as conn: - result = conn.execute(stmt) - row = result.first() - assert row - usr_id = row[users.c.id] - - yield usr_id - - with postgres_db.connect() as conn: - conn.execute(users.delete().where(users.c.id == usr_id)) + with sync_insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + postgres_db, + table=users, + values=random_user( + name="test", + ), + pk_col=users.c.id, + ) as user_row: + + yield user_row["id"] diff --git a/services/efs-guardian/tests/unit/test_efs_removal_policy_task.py b/services/efs-guardian/tests/unit/test_efs_removal_policy_task.py index 4000fab0c886..29673ef668d3 100644 --- a/services/efs-guardian/tests/unit/test_efs_removal_policy_task.py +++ b/services/efs-guardian/tests/unit/test_efs_removal_policy_task.py @@ -17,9 +17,11 @@ from models_library.users import UserID from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.postgres_tools import insert_and_get_row_lifespan +from pytest_simcore.helpers.postgres_users import ( + insert_and_get_user_and_secrets_lifespan, +) from pytest_simcore.helpers.typing_env import EnvVarsDict from simcore_postgres_database.models.projects import projects -from simcore_postgres_database.models.users import users from simcore_postgres_database.utils_repos import transaction_context from simcore_service_efs_guardian.core.settings import ( ApplicationSettings, @@ -71,12 +73,9 @@ async def user_in_db( injects a user in db """ assert user_id == user["id"] - async with insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + async with insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup app.state.engine, - table=users, - values=user, - pk_col=users.c.id, - pk_value=user["id"], + **user, ) as row: yield row diff --git a/services/payments/tests/unit/test_db_payments_users_repo.py b/services/payments/tests/unit/test_db_payments_users_repo.py index 4cff0108033d..4f63a17f4431 100644 --- a/services/payments/tests/unit/test_db_payments_users_repo.py +++ b/services/payments/tests/unit/test_db_payments_users_repo.py @@ -14,10 +14,12 @@ from models_library.users import UserID from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.postgres_tools import insert_and_get_row_lifespan +from pytest_simcore.helpers.postgres_users import ( + insert_and_get_user_and_secrets_lifespan, +) from pytest_simcore.helpers.typing_env import EnvVarsDict from simcore_postgres_database.models.payments_transactions import payments_transactions from simcore_postgres_database.models.products import products -from simcore_postgres_database.models.users import users from simcore_service_payments.db.payment_users_repo import PaymentsUsersRepo from simcore_service_payments.services.postgres import get_engine @@ -60,14 +62,10 @@ async def user( injects a user in db """ assert user_id == user["id"] - async with insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup - get_engine(app), - table=users, - values=user, - pk_col=users.c.id, - pk_value=user["id"], - ) as row: - yield row + async with insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + get_engine(app), **user + ) as user_row: + yield user_row @pytest.fixture diff --git a/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py b/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py index f00e0133b50d..02d462ac3618 100644 --- a/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py +++ b/services/web/server/src/simcore_service_webserver/garbage_collector/_core_guests.py @@ -159,7 +159,7 @@ async def remove_guest_user_with_all_its_resources( "Deleting user %s because it is a GUEST", f"{user_id=}", ) - await users_service.delete_user_without_projects(app, user_id) + await users_service.delete_user_without_projects(app, user_id=user_id) except ( DatabaseError, diff --git a/services/web/server/src/simcore_service_webserver/login/_auth_service.py b/services/web/server/src/simcore_service_webserver/login/_auth_service.py index ef0ba4893387..2d806495402e 100644 --- a/services/web/server/src/simcore_service_webserver/login/_auth_service.py +++ b/services/web/server/src/simcore_service_webserver/login/_auth_service.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import TypedDict from aiohttp import web from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON @@ -12,14 +12,50 @@ from ..products.models import Product from ..security import security_service from . import _login_service -from ._login_repository_legacy import AsyncpgStorage, get_plugin_storage -from .constants import MSG_UNKNOWN_EMAIL, MSG_WRONG_PASSWORD +from .constants import MSG_UNKNOWN_EMAIL +from .errors import WrongPasswordError -async def get_user_by_email(app: web.Application, *, email: str) -> dict[str, Any]: - db: AsyncpgStorage = get_plugin_storage(app) - user = await db.get_user({"email": email}) - return dict(user) if user else {} +class UserInfoDict(TypedDict): + id: int + name: str + email: str + role: str + status: str + first_name: str | None + last_name: str | None + phone: str | None + + +async def get_user_or_none( + app: web.Application, *, email: str | None = None, user_id: int | None = None +) -> UserInfoDict | None: + if email is None and user_id is None: + msg = "Either email or user_id must be provided" + raise ValueError(msg) + + asyncpg_engine = get_asyncpg_engine(app) + repo = UsersRepo(asyncpg_engine) + + if email is not None: + user_row = await repo.get_user_by_email_or_none(email=email.lower()) + else: + assert user_id is not None + user_row = await repo.get_user_by_id_or_none(user_id=user_id) + + if user_row is None: + return None + + return UserInfoDict( + id=user_row.id, + name=user_row.name, + email=user_row.email, + role=user_row.role.value, + status=user_row.status.value, + first_name=user_row.first_name, + last_name=user_row.last_name, + phone=user_row.phone, + ) async def create_user( @@ -29,59 +65,131 @@ async def create_user( password: str, status_upon_creation: UserStatus, expires_at: datetime | None, -) -> dict[str, Any]: +) -> UserInfoDict: - async with transaction_context(get_asyncpg_engine(app)) as conn: - user = await UsersRepo.new_user( + asyncpg_engine = get_asyncpg_engine(app) + repo = UsersRepo(asyncpg_engine) + async with transaction_context(asyncpg_engine) as conn: + user_row = await repo.new_user( conn, email=email, password_hash=security_service.encrypt_password(password), status=status_upon_creation, expires_at=expires_at, ) - await UsersRepo.link_and_update_user_from_pre_registration( - conn, new_user_id=user.id, new_user_email=user.email + await repo.link_and_update_user_from_pre_registration( + conn, + new_user_id=user_row.id, + new_user_email=user_row.email, ) - return dict(user._mapping) # pylint: disable=protected-access # noqa: SLF001 + return UserInfoDict( + id=user_row.id, + name=user_row.name, + email=user_row.email, + role=user_row.role.value, + status=user_row.status.value, + first_name=user_row.first_name, + last_name=user_row.last_name, + phone=user_row.phone, + ) -async def check_authorized_user_credentials_or_raise( - user: dict[str, Any], - password: str, - product: Product, -) -> dict: - +def check_not_null_user(user: UserInfoDict | None) -> UserInfoDict: if not user: raise web.HTTPUnauthorized( text=MSG_UNKNOWN_EMAIL, content_type=MIMETYPE_APPLICATION_JSON ) + return user - _login_service.validate_user_status(user=user, support_email=product.support_email) - if not security_service.check_password(password, user["password_hash"]): - raise web.HTTPUnauthorized( - text=MSG_WRONG_PASSWORD, content_type=MIMETYPE_APPLICATION_JSON - ) +async def check_authorized_user_credentials( + app: web.Application, + user: UserInfoDict | None, + *, + password: str, + product: Product, +) -> UserInfoDict: + """ + + Raises: + WrongPasswordError: when password is invalid + web.HTTPUnauthorized: 401 + + Returns: + user info dict + """ + + user = check_not_null_user(user) + + _login_service.validate_user_access( + user_status=user["status"], + user_role=user["role"], + support_email=product.support_email, + ) + + repo = UsersRepo(get_asyncpg_engine(app)) + + if not security_service.check_password( + password, password_hash=await repo.get_password_hash(user_id=user["id"]) + ): + raise WrongPasswordError(user_id=user["id"], product_name=product.name) return user -async def check_authorized_user_in_product_or_raise( +async def check_authorized_user_in_product( app: web.Application, *, - user: dict, + user_email: str, product: Product, ) -> None: - """Checks whether user is registered in this product""" - email = user.get("email", "").lower() + """Checks whether user is registered in this product + + + Raises: + web.HTTPUnauthorized: 401 + """ + product_group_id = product.group_id assert product_group_id is not None # nosec if ( product_group_id is not None and not await groups_service.is_user_by_email_in_group( - app, user_email=email, group_id=product_group_id + app, user_email=user_email, group_id=product_group_id ) ): - raise web.HTTPUnauthorized( - text=MSG_UNKNOWN_EMAIL, content_type=MIMETYPE_APPLICATION_JSON - ) + raise web.HTTPUnauthorized(text=MSG_UNKNOWN_EMAIL) + + +async def update_user_password( + app: web.Application, + *, + user_id: int, + current_password: str, + new_password: str, + verify_current_password: bool = True, +) -> None: + """Updates user password after verifying current password + + Keyword Arguments: + verify_current_password -- whether to check current_password is valid (default: {True}) + + Raises: + WrongPasswordError: when current password is invalid + """ + + repo = UsersRepo(get_asyncpg_engine(app)) + + if verify_current_password: + # Get current password hash + current_password_hash = await repo.get_password_hash(user_id=user_id) + + # Verify current password + if not security_service.check_password(current_password, current_password_hash): + raise WrongPasswordError(user_id=user_id) + + # Encrypt new password and update + new_password_hash = security_service.encrypt_password(new_password) + await repo.update_user_password_hash( + user_id=user_id, password_hash=new_password_hash + ) diff --git a/services/web/server/src/simcore_service_webserver/login/_controller/rest/_rest_exceptions.py b/services/web/server/src/simcore_service_webserver/login/_controller/rest/_rest_exceptions.py index 1878ba449516..7e6e58887656 100644 --- a/services/web/server/src/simcore_service_webserver/login/_controller/rest/_rest_exceptions.py +++ b/services/web/server/src/simcore_service_webserver/login/_controller/rest/_rest_exceptions.py @@ -1,5 +1,7 @@ +from aiohttp import web from common_library.user_messages import user_message from servicelib.aiohttp import status +from servicelib.aiohttp.rest_middlewares import handle_aiohttp_web_http_error from ....exception_handling import ( ExceptionToHttpErrorMap, @@ -8,8 +10,12 @@ to_exceptions_handlers_map, ) from ....users.exceptions import AlreadyPreRegisteredError -from ...constants import MSG_2FA_UNAVAILABLE -from ...errors import SendingVerificationEmailError, SendingVerificationSmsError +from ...constants import MSG_2FA_UNAVAILABLE, MSG_WRONG_PASSWORD +from ...errors import ( + SendingVerificationEmailError, + SendingVerificationSmsError, + WrongPasswordError, +) _TO_HTTP_ERROR_MAP: ExceptionToHttpErrorMap = { AlreadyPreRegisteredError: HttpErrorInfo( @@ -30,6 +36,26 @@ } +async def _handle_legacy_error_response(request: web.Request, exception: Exception): + """ + This handlers keeps compatibility with error responses that include deprecated + `ErrorGet.errors` field + + SEE packages/models-library/src/models_library/rest_error.py + """ + assert isinstance( # nosec + exception, WrongPasswordError + ), f"Expected WrongPasswordError, got {type(exception)}" + + return handle_aiohttp_web_http_error( + request=request, + exception=web.HTTPUnauthorized(text=MSG_WRONG_PASSWORD), + ) + + handle_rest_requests_exceptions = exception_handling_decorator( - to_exceptions_handlers_map(_TO_HTTP_ERROR_MAP) + { + **to_exceptions_handlers_map(_TO_HTTP_ERROR_MAP), + WrongPasswordError: _handle_legacy_error_response, + }, ) diff --git a/services/web/server/src/simcore_service_webserver/login/_controller/rest/auth.py b/services/web/server/src/simcore_service_webserver/login/_controller/rest/auth.py index 0c232c61fd82..594d4f406373 100644 --- a/services/web/server/src/simcore_service_webserver/login/_controller/rest/auth.py +++ b/services/web/server/src/simcore_service_webserver/login/_controller/rest/auth.py @@ -73,13 +73,18 @@ async def login(request: web.Request): login_data = await parse_request_body_as(LoginBody, request) # Authenticate user and verify access to the product - user = await _auth_service.check_authorized_user_credentials_or_raise( - user=await _auth_service.get_user_by_email(request.app, email=login_data.email), + user = await _auth_service.get_user_or_none(request.app, email=login_data.email) + + user = _auth_service.check_not_null_user(user) + + user = await _auth_service.check_authorized_user_credentials( + request.app, + user, password=login_data.password.get_secret_value(), product=product, ) - await _auth_service.check_authorized_user_in_product_or_raise( - request.app, user=user, product=product + await _auth_service.check_authorized_user_in_product( + request.app, user_email=user["email"], product=product ) # Check if user role allows skipping 2FA or if 2FA is not required @@ -150,7 +155,7 @@ async def login(request: web.Request): twilio_auth=settings.LOGIN_TWILIO, twilio_messaging_sid=product.twilio_messaging_sid, twilio_alpha_numeric_sender=product.twilio_alpha_numeric_sender_id, - first_name=user["first_name"], + first_name=user["first_name"] or user["name"], user_id=user["id"], ) @@ -227,8 +232,9 @@ async def login_2fa(request: web.Request): reason=MSG_WRONG_2FA_CODE__INVALID, content_type=MIMETYPE_APPLICATION_JSON ) - user = await _auth_service.get_user_by_email(request.app, email=login_2fa_.email) - assert user is not None # nosec + user = _auth_service.check_not_null_user( + await _auth_service.get_user_or_none(request.app, email=login_2fa_.email) + ) # NOTE: a priviledge user should not have called this entrypoint assert UserRole(user["role"]) <= UserRole.USER # nosec @@ -236,7 +242,7 @@ async def login_2fa(request: web.Request): # dispose since code was used await _twofa_service.delete_2fa_code(request.app, login_2fa_.email) - return await _security_service.login_granted_response(request, user=dict(user)) + return await _security_service.login_granted_response(request, user=user) @routes.post(f"/{API_VTAG}/auth/logout", name="auth_logout") diff --git a/services/web/server/src/simcore_service_webserver/login/_controller/rest/change.py b/services/web/server/src/simcore_service_webserver/login/_controller/rest/change.py index 57842919ec11..50a412f3c771 100644 --- a/services/web/server/src/simcore_service_webserver/login/_controller/rest/change.py +++ b/services/web/server/src/simcore_service_webserver/login/_controller/rest/change.py @@ -5,25 +5,23 @@ from servicelib.aiohttp.requests_validation import parse_request_body_as from servicelib.logging_errors import create_troubleshootting_log_kwargs from servicelib.request_keys import RQT_USERID_KEY -from simcore_postgres_database.utils_repos import pass_or_acquire_connection from simcore_postgres_database.utils_users import UsersRepo from ...._meta import API_VTAG from ....db.plugin import get_asyncpg_engine from ....products import products_web from ....products.models import Product -from ....security import security_service from ....users import users_service from ....utils import HOUR from ....utils_rate_limiting import global_rate_limit_route from ....web_utils import flash_response -from ... import _confirmation_service, _confirmation_web +from ... import _auth_service, _confirmation_service, _confirmation_web from ..._emails_service import get_template_path, send_email_from_template from ..._login_repository_legacy import AsyncpgStorage, get_plugin_storage from ..._login_service import ( ACTIVE, CHANGE_EMAIL, - validate_user_status, + validate_user_access, ) from ...constants import ( MSG_CANT_SEND_MAIL, @@ -34,6 +32,7 @@ MSG_WRONG_PASSWORD, ) from ...decorators import login_required +from ...errors import WrongPasswordError from ...settings import LoginOptions, get_plugin_options from .change_schemas import ChangeEmailBody, ChangePasswordBody, ResetPasswordBody @@ -120,7 +119,7 @@ def _get_error_context( ok = True # CHECK user exists - user = await db.get_user({"email": request_body.email}) + user = await _auth_service.get_user_or_none(request.app, email=request_body.email) if not user: _logger.warning( **create_troubleshootting_log_kwargs( @@ -137,7 +136,11 @@ def _get_error_context( # CHECK user state try: - validate_user_status(user=dict(user), support_email=product.support_email) + validate_user_access( + user_status=user["status"], + user_role=user["role"], + support_email=product.support_email, + ) except web.HTTPError as err: # NOTE: we abuse here (untiby reusing `validate_user_status` and catching http errors that we # do not want to forward but rather log due to the special rules in this entrypoint @@ -222,15 +225,17 @@ async def initiate_change_email(request: web.Request): request_body = await parse_request_body_as(ChangeEmailBody, request) - user = await db.get_user({"id": request[RQT_USERID_KEY]}) + user = await _auth_service.get_user_or_none( + request.app, user_id=request[RQT_USERID_KEY] + ) assert user # nosec if user["email"] == request_body.email: return flash_response("Email changed") - async with pass_or_acquire_connection(get_asyncpg_engine(request.app)) as conn: - if await UsersRepo.is_email_used(conn, email=request_body.email): - raise web.HTTPUnprocessableEntity(text="This email cannot be used") + repo = UsersRepo(get_asyncpg_engine(request.app)) + if await repo.is_email_used(email=request_body.email): + raise web.HTTPUnprocessableEntity(text="This email cannot be used") # Reset if previously requested confirmation = await db.get_confirmation({"user": user, "action": CHANGE_EMAIL}) @@ -266,24 +271,26 @@ async def initiate_change_email(request: web.Request): @login_required async def change_password(request: web.Request): - db: AsyncpgStorage = get_plugin_storage(request.app) passwords = await parse_request_body_as(ChangePasswordBody, request) + user_id = request[RQT_USERID_KEY] + user = await _auth_service.get_user_or_none(request.app, user_id=user_id) - user = await db.get_user({"id": request[RQT_USERID_KEY]}) - assert user # nosec - - if not security_service.check_password( - passwords.current.get_secret_value(), user["password_hash"] - ): - raise web.HTTPUnprocessableEntity(text=MSG_WRONG_PASSWORD) # 422 - - await db.update_user( - dict(user), - { - "password_hash": security_service.encrypt_password( - passwords.new.get_secret_value() - ) - }, + try: + await _auth_service.check_authorized_user_credentials( + request.app, + user=user, + password=passwords.current.get_secret_value(), + product=products_web.get_current_product(request), + ) + except WrongPasswordError as err: + raise web.HTTPUnprocessableEntity(text=MSG_WRONG_PASSWORD) from err + + await _auth_service.update_user_password( + request.app, + user_id=user_id, + current_password=passwords.current.get_secret_value(), + new_password=passwords.new.get_secret_value(), + verify_current_password=False, ) return flash_response(MSG_PASSWORD_CHANGED) diff --git a/services/web/server/src/simcore_service_webserver/login/_controller/rest/confirmation.py b/services/web/server/src/simcore_service_webserver/login/_controller/rest/confirmation.py index 986919365bb8..1f05e965d686 100644 --- a/services/web/server/src/simcore_service_webserver/login/_controller/rest/confirmation.py +++ b/services/web/server/src/simcore_service_webserver/login/_controller/rest/confirmation.py @@ -15,18 +15,22 @@ ) from servicelib.logging_errors import create_troubleshootting_log_kwargs from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON -from simcore_postgres_database.aiopg_errors import UniqueViolation from yarl import URL from ....products import products_web from ....products.models import Product -from ....security import security_service from ....session.access_policies import session_access_required from ....utils import HOUR, MINUTE from ....utils_aiohttp import create_redirect_to_page_response from ....utils_rate_limiting import global_rate_limit_route from ....web_utils import flash_response -from ... import _confirmation_service, _security_service, _twofa_service +from ... import ( + _auth_service, + _confirmation_service, + _registration_service, + _security_service, + _twofa_service, +) from ..._login_repository_legacy import ( AsyncpgStorage, ConfirmationTokenDict, @@ -207,8 +211,6 @@ async def phone_confirmation(request: web.Request): request.app, product_name=product.name ) - db: AsyncpgStorage = get_plugin_storage(request.app) - if not settings.LOGIN_2FA_REQUIRED: raise web.HTTPServiceUnavailable( text="Phone registration is not available", @@ -223,19 +225,15 @@ async def phone_confirmation(request: web.Request): # consumes code await _twofa_service.delete_2fa_code(request.app, request_body.email) - # updates confirmed phone number - try: - user = await db.get_user({"email": request_body.email}) - assert user is not None # nosec - await db.update_user(dict(user), {"phone": request_body.phone}) - - except UniqueViolation as err: - raise web.HTTPUnauthorized( - text="Invalid phone number", - content_type=MIMETYPE_APPLICATION_JSON, - ) from err + user = _auth_service.check_not_null_user( + await _auth_service.get_user_or_none(request.app, email=request_body.email) + ) - return await _security_service.login_granted_response(request, user=dict(user)) + await _registration_service.register_user_phone( + request.app, user_id=user["id"], user_phone=request_body.phone + ) + + return await _security_service.login_granted_response(request, user=user) # fails because of invalid or no code raise web.HTTPUnauthorized( @@ -263,17 +261,19 @@ async def complete_reset_password(request: web.Request): ) if confirmation: - user = await db.get_user({"id": confirmation["user_id"]}) + user = await _auth_service.get_user_or_none( + request.app, user_id=confirmation["user_id"] + ) assert user # nosec - await db.update_user( - user={"id": user["id"]}, - updates={ - "password_hash": security_service.encrypt_password( - request_body.password.get_secret_value() - ) - }, + await _auth_service.update_user_password( + request.app, + user_id=user["id"], + current_password="", + new_password=request_body.password.get_secret_value(), + verify_current_password=False, # confirmed by code ) + await db.delete_confirmation(confirmation) return flash_response(MSG_PASSWORD_CHANGED) @@ -282,5 +282,4 @@ async def complete_reset_password(request: web.Request): text=MSG_PASSWORD_CHANGE_NOT_ALLOWED.format( support_email=product.support_email ), - content_type=MIMETYPE_APPLICATION_JSON, ) # 401 diff --git a/services/web/server/src/simcore_service_webserver/login/_controller/rest/registration.py b/services/web/server/src/simcore_service_webserver/login/_controller/rest/registration.py index b58635817fec..796a47241565 100644 --- a/services/web/server/src/simcore_service_webserver/login/_controller/rest/registration.py +++ b/services/web/server/src/simcore_service_webserver/login/_controller/rest/registration.py @@ -185,9 +185,10 @@ async def register(request: web.Request): ).replace(tzinfo=None) # get authorized user or create new - user = await _auth_service.get_user_by_email(request.app, email=registration.email) + user = await _auth_service.get_user_or_none(request.app, email=registration.email) if user: - await _auth_service.check_authorized_user_credentials_or_raise( + await _auth_service.check_authorized_user_credentials( + request.app, user, password=registration.password.get_secret_value(), product=product, @@ -205,6 +206,8 @@ async def register(request: web.Request): expires_at=expires_at, ) + assert user is not None # nosec + # setup user groups assert ( # nosec product.name == invitation.product @@ -267,7 +270,7 @@ async def register(request: web.Request): ) ) - await db.delete_confirmation_and_user(user, _confirmation) + await db.delete_confirmation_and_user(user["id"], _confirmation) raise web.HTTPServiceUnavailable(text=user_error_msg) from err diff --git a/services/web/server/src/simcore_service_webserver/login/_controller/rest/twofa.py b/services/web/server/src/simcore_service_webserver/login/_controller/rest/twofa.py index 018e23ee1c05..62bce235f1d0 100644 --- a/services/web/server/src/simcore_service_webserver/login/_controller/rest/twofa.py +++ b/services/web/server/src/simcore_service_webserver/login/_controller/rest/twofa.py @@ -2,16 +2,15 @@ from aiohttp import web from aiohttp.web import RouteTableDef +from common_library.user_messages import user_message from servicelib.aiohttp import status from servicelib.aiohttp.requests_validation import parse_request_body_as -from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON from ....products import products_web from ....products.models import Product from ....session.access_policies import session_access_required from ....web_utils import envelope_response -from ... import _twofa_service -from ..._login_repository_legacy import AsyncpgStorage, get_plugin_storage +from ... import _auth_service, _twofa_service from ...constants import ( CODE_2FA_EMAIL_CODE_REQUIRED, CODE_2FA_SMS_CODE_REQUIRED, @@ -41,19 +40,15 @@ async def resend_2fa_code(request: web.Request): settings: LoginSettingsForProduct = get_plugin_settings( request.app, product_name=product.name ) - db: AsyncpgStorage = get_plugin_storage(request.app) resend_2fa_ = await parse_request_body_as(Resend2faBody, request) - user = await db.get_user({"email": resend_2fa_.email}) + user = await _auth_service.get_user_or_none(request.app, email=resend_2fa_.email) if not user: - raise web.HTTPUnauthorized( - text=MSG_UNKNOWN_EMAIL, content_type=MIMETYPE_APPLICATION_JSON - ) + raise web.HTTPUnauthorized(text=MSG_UNKNOWN_EMAIL) if not settings.LOGIN_2FA_REQUIRED: raise web.HTTPServiceUnavailable( - text="2FA login is not available", - content_type=MIMETYPE_APPLICATION_JSON, + text=user_message("2FA login is not available") ) # Already a code? @@ -77,8 +72,14 @@ async def resend_2fa_code(request: web.Request): # sends via SMS if resend_2fa_.via == "SMS": + user_phone_number = user.get("phone") + if not user_phone_number: + raise web.HTTPBadRequest( + text=user_message("User does not have a phone number registered") + ) + await _twofa_service.send_sms_code( - phone_number=user["phone"], + phone_number=user_phone_number, code=code, twilio_auth=settings.LOGIN_TWILIO, twilio_messaging_sid=product.twilio_messaging_sid, @@ -92,7 +93,7 @@ async def resend_2fa_code(request: web.Request): "name": CODE_2FA_SMS_CODE_REQUIRED, "parameters": { "message": MSG_2FA_CODE_SENT.format( - phone_number=_twofa_service.mask_phone_number(user["phone"]) + phone_number=_twofa_service.mask_phone_number(user_phone_number) ), "expiration_2fa": settings.LOGIN_2FA_CODE_EXPIRATION_SEC, }, diff --git a/services/web/server/src/simcore_service_webserver/login/_invitations_service.py b/services/web/server/src/simcore_service_webserver/login/_invitations_service.py index 0b99e459970f..3a8c5ceb13a2 100644 --- a/services/web/server/src/simcore_service_webserver/login/_invitations_service.py +++ b/services/web/server/src/simcore_service_webserver/login/_invitations_service.py @@ -40,7 +40,8 @@ InvitationsServiceUnavailableError, ) from ..products.models import Product -from . import _confirmation_service +from ..users import users_service +from . import _auth_service, _confirmation_service from ._login_repository_legacy import ( AsyncpgStorage, BaseConfirmationTokenDict, @@ -114,8 +115,9 @@ async def check_other_registrations( db: AsyncpgStorage, cfg: LoginOptions, ) -> None: + # An account is already registered with this email - if user := await db.get_user({"email": email}): + if user := await _auth_service.get_user_or_none(app, email=email): user_status = UserStatus(user["status"]) match user_status: @@ -143,10 +145,12 @@ async def check_other_registrations( ) if drop_previous_registration: if not _confirmation: - await db.delete_user(user=dict(user)) + await users_service.delete_user_without_projects( + app, user_id=user["id"], clean_cache=False + ) else: await db.delete_confirmation_and_user( - user=dict(user), confirmation=_confirmation + user_id=user["id"], confirmation=_confirmation ) _logger.warning( diff --git a/services/web/server/src/simcore_service_webserver/login/_login_repository_legacy.py b/services/web/server/src/simcore_service_webserver/login/_login_repository_legacy.py index d119c462d8b3..33be73a0fb1a 100644 --- a/services/web/server/src/simcore_service_webserver/login/_login_repository_legacy.py +++ b/services/web/server/src/simcore_service_webserver/login/_login_repository_legacy.py @@ -48,44 +48,6 @@ def __init__( self.user_tbl = user_table_name self.confirm_tbl = confirmation_table_name - # - # CRUD user - # - - async def get_user(self, with_data: dict[str, Any]) -> asyncpg.Record | None: - async with self.pool.acquire() as conn: - return await _login_repository_legacy_sql.find_one( - conn, self.user_tbl, with_data - ) - - async def create_user(self, data: dict[str, Any]) -> dict[str, Any]: - async with self.pool.acquire() as conn: - user_id = await _login_repository_legacy_sql.insert( - conn, self.user_tbl, data - ) - new_user = await _login_repository_legacy_sql.find_one( - conn, self.user_tbl, {"id": user_id} - ) - assert new_user # nosec - data.update( - id=new_user["id"], - created_at=new_user["created_at"], - primary_gid=new_user["primary_gid"], - ) - return data - - async def update_user(self, user: dict[str, Any], updates: dict[str, Any]) -> None: - async with self.pool.acquire() as conn: - await _login_repository_legacy_sql.update( - conn, self.user_tbl, {"id": user["id"]}, updates - ) - - async def delete_user(self, user: dict[str, Any]) -> None: - async with self.pool.acquire() as conn: - await _login_repository_legacy_sql.delete( - conn, self.user_tbl, {"id": user["id"]} - ) - # # CRUD confirmation # @@ -142,14 +104,14 @@ async def delete_confirmation(self, confirmation: ConfirmationTokenDict): # async def delete_confirmation_and_user( - self, user: dict[str, Any], confirmation: ConfirmationTokenDict + self, user_id: int, confirmation: ConfirmationTokenDict ): async with self.pool.acquire() as conn, conn.transaction(): await _login_repository_legacy_sql.delete( conn, self.confirm_tbl, {"code": confirmation["code"]} ) await _login_repository_legacy_sql.delete( - conn, self.user_tbl, {"id": user["id"]} + conn, self.user_tbl, {"id": user_id} ) async def delete_confirmation_and_update_user( diff --git a/services/web/server/src/simcore_service_webserver/login/_login_service.py b/services/web/server/src/simcore_service_webserver/login/_login_service.py index 3dd6364ff955..345153e55557 100644 --- a/services/web/server/src/simcore_service_webserver/login/_login_service.py +++ b/services/web/server/src/simcore_service_webserver/login/_login_service.py @@ -39,22 +39,19 @@ def _to_names(enum_cls, names) -> list[str]: ) -def validate_user_status(*, user: dict, support_email: str): +def validate_user_access(*, user_status: str, user_role: str, support_email: str): """ Raises: web.HTTPUnauthorized """ - assert "role" in user # nosec - - user_status: str = user["status"] if user_status == DELETED: raise web.HTTPUnauthorized( text=MSG_USER_DELETED.format(support_email=support_email), ) # 401 - if user_status == BANNED or user["role"] == ANONYMOUS: + if user_status == BANNED or user_role == ANONYMOUS: raise web.HTTPUnauthorized( text=MSG_USER_BANNED.format(support_email=support_email), ) # 401 diff --git a/services/web/server/src/simcore_service_webserver/login/_registration_service.py b/services/web/server/src/simcore_service_webserver/login/_registration_service.py index 25a0ad56cfcc..21b897327efd 100644 --- a/services/web/server/src/simcore_service_webserver/login/_registration_service.py +++ b/services/web/server/src/simcore_service_webserver/login/_registration_service.py @@ -1,2 +1,16 @@ +from aiohttp import web +from simcore_postgres_database.utils_users import UsersRepo + +from ..db.plugin import get_asyncpg_engine + + def get_user_name_from_email(email: str) -> str: return email.split("@")[0] + + +async def register_user_phone( + app: web.Application, *, user_id: int, user_phone: str +) -> None: + asyncpg_engine = get_asyncpg_engine(app) + repo = UsersRepo(asyncpg_engine) + await repo.update_user_phone(user_id=user_id, phone=user_phone) diff --git a/services/web/server/src/simcore_service_webserver/login/_security_service.py b/services/web/server/src/simcore_service_webserver/login/_security_service.py index dc5aa049cf90..82738d1ace5c 100644 --- a/services/web/server/src/simcore_service_webserver/login/_security_service.py +++ b/services/web/server/src/simcore_service_webserver/login/_security_service.py @@ -1,20 +1,20 @@ """Utils that extends on security_api plugin""" import logging -from typing import Any from aiohttp import web from servicelib.logging_utils import get_log_record_extra, log_context from ..security import security_web from ..web_utils import flash_response +from ._auth_service import UserInfoDict from .constants import MSG_LOGGED_IN _logger = logging.getLogger(__name__) async def login_granted_response( - request: web.Request, *, user: dict[str, Any] + request: web.Request, *, user: UserInfoDict ) -> web.Response: """ Grants authorization for user creating a responses with an auth cookie diff --git a/services/web/server/src/simcore_service_webserver/login/errors.py b/services/web/server/src/simcore_service_webserver/login/errors.py index 8db82ab55a00..c6512b2b23a3 100644 --- a/services/web/server/src/simcore_service_webserver/login/errors.py +++ b/services/web/server/src/simcore_service_webserver/login/errors.py @@ -10,3 +10,7 @@ class SendingVerificationSmsError(LoginError): class SendingVerificationEmailError(LoginError): msg_template = "Sending verification email failed. {reason}" + + +class WrongPasswordError(LoginError): + msg_template = "Invalid password provided" diff --git a/services/web/server/src/simcore_service_webserver/login_accounts/_controller_rest.py b/services/web/server/src/simcore_service_webserver/login_accounts/_controller_rest.py index c00156ec79ff..bdc46745ab5b 100644 --- a/services/web/server/src/simcore_service_webserver/login_accounts/_controller_rest.py +++ b/services/web/server/src/simcore_service_webserver/login_accounts/_controller_rest.py @@ -135,11 +135,13 @@ async def unregister_account(request: web.Request): credentials = await users_service.get_user_credentials( request.app, user_id=req_ctx.user_id ) - if body.email != credentials.email.lower() or not security_service.check_password( + if body.email != credentials.email or not security_service.check_password( body.password.get_secret_value(), credentials.password_hash ): raise web.HTTPConflict( - text="Wrong email or password. Please try again to delete this account" + text=user_message( + "Wrong email or password. Please try again to delete this account" + ) ) with log_context( diff --git a/services/web/server/src/simcore_service_webserver/publications/_rest.py b/services/web/server/src/simcore_service_webserver/publications/_rest.py index 6966a36baf22..e867c2042697 100644 --- a/services/web/server/src/simcore_service_webserver/publications/_rest.py +++ b/services/web/server/src/simcore_service_webserver/publications/_rest.py @@ -12,8 +12,8 @@ from .._meta import API_VTAG as VTAG from ..login._emails_service import AttachmentTuple, send_email_from_template, themed from ..login.decorators import login_required -from ..login.login_repository_legacy import AsyncpgStorage, get_plugin_storage from ..products import products_web +from ..users import users_service from ._utils import json2html _logger = logging.getLogger(__name__) @@ -57,11 +57,9 @@ async def service_submission(request: web.Request): support_email_address = product.support_email - db: AsyncpgStorage = get_plugin_storage(request.app) - user = await db.get_user({"id": request[RQT_USERID_KEY]}) - assert user # nosec - user_email = user.get("email") - assert user_email # nosec + user = await users_service.get_user_name_and_email( + request.app, user_id=request[RQT_USERID_KEY] + ) try: attachments = [ @@ -80,11 +78,11 @@ async def service_submission(request: web.Request): # send email await send_email_from_template( request, - from_=user_email, + from_=user.email, to=support_email_address, template=themed("templates/common", _EMAIL_TEMPLATE_NAME), context={ - "user": user_email, + "user": user.email, "data": json2html.convert( json=json_dumps(data), table_attributes='class="pure-table"' ), diff --git a/services/web/server/src/simcore_service_webserver/studies_dispatcher/_users.py b/services/web/server/src/simcore_service_webserver/studies_dispatcher/_users.py index 626c3a19752c..873274fe7125 100644 --- a/services/web/server/src/simcore_service_webserver/studies_dispatcher/_users.py +++ b/services/web/server/src/simcore_service_webserver/studies_dispatcher/_users.py @@ -17,18 +17,21 @@ import redis.asyncio as aioredis from aiohttp import web +from common_library.users_enums import UserRole, UserStatus from models_library.emails import LowerCaseEmailStr +from models_library.users import UserID from pydantic import BaseModel, TypeAdapter from redis.exceptions import LockNotOwnedError from servicelib.aiohttp.application_keys import APP_FIRE_AND_FORGET_TASKS_KEY from servicelib.logging_utils import log_decorator from servicelib.utils import fire_and_forget_task from servicelib.utils_secrets import generate_password +from simcore_postgres_database.utils_users import UsersRepo +from ..db.plugin import get_asyncpg_engine from ..garbage_collector.settings import GUEST_USER_RC_LOCK_FORMAT -from ..groups.api import auto_add_user_to_product_group -from ..login._login_service import ACTIVE, GUEST -from ..login.login_repository_legacy import AsyncpgStorage, get_plugin_storage +from ..groups import api as groups_service +from ..login._login_service import GUEST from ..products import products_web from ..redis import get_redis_lock_manager_client from ..security import security_service, security_web @@ -95,7 +98,6 @@ async def create_temporary_guest_user(request: web.Request): MaxGuestUsersError: No more guest users allowed """ - db: AsyncpgStorage = get_plugin_storage(request.app) redis_locks_client: aioredis.Redis = get_redis_lock_manager_client(request.app) settings: StudiesDispatcherSettings = get_plugin_settings(app=request.app) product_name = products_web.get_product_name(request) @@ -109,31 +111,33 @@ async def create_temporary_guest_user(request: web.Request): password = generate_password(length=12) expires_at = datetime.utcnow() + settings.STUDIES_GUEST_ACCOUNT_LIFETIME - usr = None + user_id: UserID | None = None + + repo = UsersRepo(get_asyncpg_engine(request.app)) + try: async with redis_locks_client.lock( GUEST_USER_RC_LOCK_FORMAT.format(user_id=random_user_name), timeout=MAX_DELAY_TO_CREATE_USER, ): # NOTE: usr Dict is incomplete, e.g. does not contain primary_gid - usr = await db.create_user( - { - "name": random_user_name, - "email": email, - "password_hash": security_service.encrypt_password(password), - "status": ACTIVE, - "role": GUEST, - "expires_at": expires_at, - } + user_row = await repo.new_user( + email=email, + password_hash=security_service.encrypt_password(password), + status=UserStatus.ACTIVE, + role=UserRole.GUEST, + expires_at=expires_at, ) - user = await users_service.get_user(request.app, usr["id"]) - await auto_add_user_to_product_group( - request.app, user_id=user["id"], product_name=product_name + user_id = user_row.id + + user = await users_service.get_user(request.app, user_id) + await groups_service.auto_add_user_to_product_group( + request.app, user_id=user_id, product_name=product_name ) # (2) read details above await redis_locks_client.lock( - GUEST_USER_RC_LOCK_FORMAT.format(user_id=user["id"]), + GUEST_USER_RC_LOCK_FORMAT.format(user_id=user_id), timeout=MAX_DELAY_TO_GUEST_FIRST_CONNECTION, ).acquire() @@ -146,14 +150,16 @@ async def create_temporary_guest_user(request: web.Request): # stop creating GUEST users. # NOTE: here we cleanup but if any trace is left it will be deleted by gc - if usr is not None and usr.get("id"): + if user_id: - async def _cleanup(draft_user): + async def _cleanup(): with suppress(Exception): - await db.delete_user(draft_user) + await users_service.delete_user_without_projects( + request.app, user_id=user_id, clean_cache=False + ) fire_and_forget_task( - _cleanup(usr), + _cleanup(), task_suffix_name="cleanup_temporary_guest_user", fire_and_forget_tasks_collection=request.app[ APP_FIRE_AND_FORGET_TASKS_KEY diff --git a/services/web/server/src/simcore_service_webserver/users/_users_repository.py b/services/web/server/src/simcore_service_webserver/users/_users_repository.py index 8b2f5ec4a085..dd17431498f7 100644 --- a/services/web/server/src/simcore_service_webserver/users/_users_repository.py +++ b/services/web/server/src/simcore_service_webserver/users/_users_repository.py @@ -372,13 +372,10 @@ async def get_user_billing_details( Raises: BillingDetailsNotFoundError """ - async with pass_or_acquire_connection(engine, connection) as conn: - query = UsersRepo.get_billing_details_query(user_id=user_id) - result = await conn.execute(query) - row = result.first() - if not row: - raise BillingDetailsNotFoundError(user_id=user_id) - return UserBillingDetails.model_validate(row) + row = await UsersRepo(engine).get_billing_details(connection, user_id=user_id) + if not row: + raise BillingDetailsNotFoundError(user_id=user_id) + return UserBillingDetails.model_validate(row) async def delete_user_by_id( diff --git a/services/web/server/src/simcore_service_webserver/users/_users_service.py b/services/web/server/src/simcore_service_webserver/users/_users_service.py index fa4a7f28f5b9..d23ce3fe38d0 100644 --- a/services/web/server/src/simcore_service_webserver/users/_users_service.py +++ b/services/web/server/src/simcore_service_webserver/users/_users_service.py @@ -16,6 +16,7 @@ from simcore_postgres_database.utils_groups_extra_properties import ( GroupExtraPropertiesNotFoundError, ) +from simcore_postgres_database.utils_users import UsersRepo from ..db.plugin import get_asyncpg_engine from ..security import security_service @@ -30,6 +31,7 @@ ) from .exceptions import ( MissingGroupExtraPropertiesForProductError, + UserNotFoundError, ) _logger = logging.getLogger(__name__) @@ -159,21 +161,19 @@ async def get_user_display_and_id_names( async def get_user_credentials( app: web.Application, *, user_id: UserID ) -> UserCredentialsTuple: - row = await _users_repository.get_user_or_raise( - get_asyncpg_engine(app), - user_id=user_id, - return_column_names=[ - "name", - "first_name", - "email", - "password_hash", - ], - ) + + repo = UsersRepo(get_asyncpg_engine(app)) + + user_row = await repo.get_user_by_id_or_none(user_id=user_id) + if user_row is None: + raise UserNotFoundError(user_id=user_id) + + user_password_hash = await repo.get_password_hash(user_id=user_id) return UserCredentialsTuple( - email=TypeAdapter(LowerCaseEmailStr).validate_python(row["email"]), - password_hash=row["password_hash"], - display_name=row["first_name"] or row["name"].capitalize(), + email=TypeAdapter(LowerCaseEmailStr).validate_python(user_row.email), + password_hash=user_password_hash, + display_name=user_row.first_name or user_row.name.capitalize(), ) @@ -213,7 +213,9 @@ async def get_user_invoice_address( # -async def delete_user_without_projects(app: web.Application, user_id: UserID) -> None: +async def delete_user_without_projects( + app: web.Application, *, user_id: UserID, clean_cache: bool = True +) -> None: """Deletes a user from the database if the user exists""" # WARNING: user cannot be deleted without deleting first all ist project # otherwise this function will raise asyncpg.exceptions.ForeignKeyViolationError @@ -228,9 +230,10 @@ async def delete_user_without_projects(app: web.Application, user_id: UserID) -> ) return - # This user might be cached in the auth. If so, any request - # with this user-id will get thru producing unexpected side-effects - await security_service.clean_auth_policy_cache(app) + if clean_cache: + # This user might be cached in the auth. If so, any request + # with this user-id will get thru producing unexpected side-effects + await security_service.clean_auth_policy_cache(app) async def set_user_as_deleted(app: web.Application, *, user_id: UserID) -> None: diff --git a/services/web/server/tests/unit/with_dbs/03/invitations/test_users_accounts_rest_registration.py b/services/web/server/tests/unit/with_dbs/03/invitations/test_users_accounts_rest_registration.py index c1964c4a46c9..4df5e0e49aae 100644 --- a/services/web/server/tests/unit/with_dbs/03/invitations/test_users_accounts_rest_registration.py +++ b/services/web/server/tests/unit/with_dbs/03/invitations/test_users_accounts_rest_registration.py @@ -252,7 +252,7 @@ async def test_search_and_pre_registration( got = UserAccountGet(**found[0], state=None) assert got.model_dump(include={"registered", "status"}) == { "registered": True, - "status": new_user["status"].name, + "status": new_user["status"], } @@ -327,6 +327,7 @@ async def test_list_users_accounts( status_upon_creation=UserStatus.ACTIVE, expires_at=None, ) + assert new_user["status"] == UserStatus.ACTIVE # 3. Test filtering by status # a. Check PENDING filter (should exclude the registered user) diff --git a/services/web/server/tests/unit/with_dbs/03/login/test_login_change_password.py b/services/web/server/tests/unit/with_dbs/03/login/test_login_change_password.py index 2653591ea902..9a3f422037be 100644 --- a/services/web/server/tests/unit/with_dbs/03/login/test_login_change_password.py +++ b/services/web/server/tests/unit/with_dbs/03/login/test_login_change_password.py @@ -34,7 +34,6 @@ async def test_unauthorized_to_change_password(client: TestClient, new_password: "confirm": new_password, }, ) - assert response.status == 401 await assert_status(response, status.HTTP_401_UNAUTHORIZED) @@ -54,7 +53,7 @@ async def test_wrong_current_password( }, ) assert response.url.path == url.path - assert response.status == 422 + assert response.status == status.HTTP_422_UNPROCESSABLE_ENTITY assert MSG_WRONG_PASSWORD in await response.text() await assert_status( response, status.HTTP_422_UNPROCESSABLE_ENTITY, MSG_WRONG_PASSWORD @@ -111,13 +110,13 @@ async def test_success(client: TestClient, new_password: str): }, ) assert response.url.path == url_change_password.path - assert response.status == 200 + assert response.status == status.HTTP_200_OK assert MSG_PASSWORD_CHANGED in await response.text() await assert_status(response, status.HTTP_200_OK, MSG_PASSWORD_CHANGED) # logout response = await client.post(f"{url_logout}") - assert response.status == 200 + assert response.status == status.HTTP_200_OK assert response.url.path == url_logout.path # login with new password @@ -128,6 +127,6 @@ async def test_success(client: TestClient, new_password: str): "password": new_password, }, ) - assert response.status == 200 + assert response.status == status.HTTP_200_OK assert response.url.path == url_login.path await assert_status(response, status.HTTP_200_OK, MSG_LOGGED_IN) diff --git a/services/web/server/tests/unit/with_dbs/03/login/test_login_logout.py b/services/web/server/tests/unit/with_dbs/03/login/test_login_logout.py index 13aa95c32e41..52ab61ed19ed 100644 --- a/services/web/server/tests/unit/with_dbs/03/login/test_login_logout.py +++ b/services/web/server/tests/unit/with_dbs/03/login/test_login_logout.py @@ -15,8 +15,7 @@ async def test_logout(client: TestClient, db: AsyncpgStorage): logout_url = client.app.router["auth_logout"].url_for() protected_url = client.app.router["get_my_profile"].url_for() - async with LoggedUser(client) as user: - + async with LoggedUser(client): # try to access protected page response = await client.get(f"{protected_url}") assert response.url.path == protected_url.path @@ -31,5 +30,3 @@ async def test_logout(client: TestClient, db: AsyncpgStorage): response = await client.get(f"{protected_url}") assert response.url.path == protected_url.path await assert_status(response, status.HTTP_401_UNAUTHORIZED) - - await db.delete_user(user) diff --git a/services/web/server/tests/unit/with_dbs/03/login/test_login_registration.py b/services/web/server/tests/unit/with_dbs/03/login/test_login_registration.py index 31dd881fa303..d54213cb29b2 100644 --- a/services/web/server/tests/unit/with_dbs/03/login/test_login_registration.py +++ b/services/web/server/tests/unit/with_dbs/03/login/test_login_registration.py @@ -19,6 +19,7 @@ from servicelib.rest_responses import unwrap_envelope from simcore_service_webserver.db.models import UserStatus from simcore_service_webserver.groups.api import auto_add_user_to_product_group +from simcore_service_webserver.login import _auth_service from simcore_service_webserver.login._confirmation_web import _url_for_confirmation from simcore_service_webserver.login._login_repository_legacy import AsyncpgStorage from simcore_service_webserver.login.constants import ( @@ -263,7 +264,6 @@ async def test_registration_with_invalid_confirmation_code( async def test_registration_without_confirmation( client: TestClient, - db: AsyncpgStorage, mocker: MockerFixture, user_email: str, user_password: str, @@ -293,13 +293,12 @@ async def test_registration_without_confirmation( data, _ = await assert_status(response, status.HTTP_200_OK) assert MSG_LOGGED_IN in data["message"] - user = await db.get_user({"email": user_email}) + user = await _auth_service.get_user_or_none(client.app, email=user_email) assert user async def test_registration_with_confirmation( client: TestClient, - db: AsyncpgStorage, capsys: pytest.CaptureFixture, mocker: MockerFixture, user_email: str, @@ -331,7 +330,8 @@ async def test_registration_with_confirmation( data, error = unwrap_envelope(await response.json()) assert response.status == 200, (data, error) - user = await db.get_user({"email": user_email}) + user = await _auth_service.get_user_or_none(client.app, email=user_email) + assert user assert user["status"] == UserStatus.CONFIRMATION_PENDING.name assert "verification link" in data["message"] @@ -350,7 +350,8 @@ async def test_registration_with_confirmation( assert response.status == 200 # user is active - user = await db.get_user({"email": user_email}) + user = await _auth_service.get_user_or_none(client.app, email=user_email) + assert user assert user["status"] == UserStatus.ACTIVE.name diff --git a/services/web/server/tests/unit/with_dbs/03/login/test_login_twofa.py b/services/web/server/tests/unit/with_dbs/03/login/test_login_twofa.py index 69cbc3b2ccdc..b31b1fae3fcb 100644 --- a/services/web/server/tests/unit/with_dbs/03/login/test_login_twofa.py +++ b/services/web/server/tests/unit/with_dbs/03/login/test_login_twofa.py @@ -21,6 +21,7 @@ from simcore_postgres_database.models.products import ProductLoginSettingsDict, products from simcore_service_webserver.application_settings import ApplicationSettings from simcore_service_webserver.db.models import UserStatus +from simcore_service_webserver.login import _auth_service from simcore_service_webserver.login._login_repository_legacy import AsyncpgStorage from simcore_service_webserver.login._twofa_service import ( _do_create_2fa_code, @@ -161,7 +162,8 @@ def _get_confirmation_link_from_email(): assert response.status == status.HTTP_200_OK # check email+password registered - user = await db.get_user({"email": user_email}) + user = await _auth_service.get_user_or_none(client.app, email=user_email) + assert user assert user["status"] == UserStatus.ACTIVE.name assert user["phone"] is None @@ -195,7 +197,8 @@ def _get_confirmation_link_from_email(): assert phone == user_phone_number # check phone still NOT in db (TODO: should be in database and unconfirmed) - user = await db.get_user({"email": user_email}) + user = await _auth_service.get_user_or_none(client.app, email=user_email) + assert user assert user["status"] == UserStatus.ACTIVE.name assert user["phone"] is None @@ -211,7 +214,8 @@ def _get_confirmation_link_from_email(): ) await assert_status(response, status.HTTP_200_OK) # check user has phone confirmed - user = await db.get_user({"email": user_email}) + user = await _auth_service.get_user_or_none(client.app, email=user_email) + assert user assert user["status"] == UserStatus.ACTIVE.name assert user["phone"] == user_phone_number @@ -252,7 +256,8 @@ def _get_confirmation_link_from_email(): await assert_status(response, status.HTTP_200_OK) # assert users is successfully registered - user = await db.get_user({"email": user_email}) + user = await _auth_service.get_user_or_none(client.app, email=user_email) + assert user assert user["email"] == user_email assert user["phone"] == user_phone_number assert user["status"] == UserStatus.ACTIVE.value diff --git a/services/web/server/tests/unit/with_dbs/03/users/conftest.py b/services/web/server/tests/unit/with_dbs/03/users/conftest.py index 8911ab95d35d..2272c5bc9f62 100644 --- a/services/web/server/tests/unit/with_dbs/03/users/conftest.py +++ b/services/web/server/tests/unit/with_dbs/03/users/conftest.py @@ -11,7 +11,6 @@ import sqlalchemy as sa from aiohttp import web from aiohttp.test_utils import TestServer -from faker import Faker from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.aiohttp.application import create_safe_application from simcore_postgres_database.models.users_details import ( @@ -70,24 +69,19 @@ async def pre_registration_details_db_cleanup( @pytest.fixture async def product_owner_user( - faker: Faker, asyncpg_engine: AsyncEngine, ) -> AsyncIterable[dict[str, Any]]: """A PO user in the database""" - from pytest_simcore.helpers.faker_factories import random_user - from pytest_simcore.helpers.postgres_tools import insert_and_get_row_lifespan - from simcore_postgres_database.models.users import UserRole, users + from pytest_simcore.helpers.postgres_users import ( + insert_and_get_user_and_secrets_lifespan, + ) + from simcore_postgres_database.models.users import UserRole - async with insert_and_get_row_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup + async with insert_and_get_user_and_secrets_lifespan( # pylint:disable=contextmanager-generator-missing-cleanup asyncpg_engine, - table=users, - values=random_user( - faker, - email="po-user@email.com", - name="po-user-fixture", - role=UserRole.PRODUCT_OWNER, - ), - pk_col=users.c.id, + email="po-user@email.com", + name="po-user-fixture", + role=UserRole.PRODUCT_OWNER, ) as record: yield record diff --git a/services/web/server/tests/unit/with_dbs/04/garbage_collector/test_resource_manager.py b/services/web/server/tests/unit/with_dbs/04/garbage_collector/test_resource_manager.py index 874182e86a43..13b5b1ca8658 100644 --- a/services/web/server/tests/unit/with_dbs/04/garbage_collector/test_resource_manager.py +++ b/services/web/server/tests/unit/with_dbs/04/garbage_collector/test_resource_manager.py @@ -673,7 +673,6 @@ async def test_interactive_services_removed_per_project( mocked_notification_system, socketio_client_factory: Callable, client_session_id_factory: Callable[[], str], - asyncpg_storage_system_mock, storage_subsystem_mock, # when guest user logs out garbage is collected expected_save_state: bool, open_project: Callable, diff --git a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py index 9057497e0b5b..0647b39eb2a4 100644 --- a/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py +++ b/services/web/server/tests/unit/with_dbs/04/studies_dispatcher/test_studies_dispatcher_studies_access.py @@ -421,7 +421,7 @@ async def enforce_garbage_collect_guest(uid): ) await delete_task - await delete_user_without_projects(app, uid) + await delete_user_without_projects(app, user_id=uid) return uid user_id = await enforce_garbage_collect_guest(uid=data["id"]) diff --git a/services/web/server/tests/unit/with_dbs/conftest.py b/services/web/server/tests/unit/with_dbs/conftest.py index 539006a5d957..5705c4b95ca9 100644 --- a/services/web/server/tests/unit/with_dbs/conftest.py +++ b/services/web/server/tests/unit/with_dbs/conftest.py @@ -420,14 +420,6 @@ async def _mock_result() -> None: return MockedStorageSubsystem(mock, mock1, mock2, mock3) -@pytest.fixture -def asyncpg_storage_system_mock(mocker): - return mocker.patch( - "simcore_service_webserver.login._login_repository_legacy.AsyncpgStorage.delete_user", - return_value="", - ) - - @pytest.fixture async def mocked_dynamic_services_interface( mocker: MockerFixture,