diff --git a/packages/postgres-database/src/simcore_postgres_database/migration/versions/7e92447558e0_update_api_keys_uniqueness_constraint.py b/packages/postgres-database/src/simcore_postgres_database/migration/versions/7e92447558e0_update_api_keys_uniqueness_constraint.py new file mode 100644 index 000000000000..e3a42a641256 --- /dev/null +++ b/packages/postgres-database/src/simcore_postgres_database/migration/versions/7e92447558e0_update_api_keys_uniqueness_constraint.py @@ -0,0 +1,37 @@ +"""Update api-keys uniqueness constraint + +Revision ID: 7e92447558e0 +Revises: 06eafd25d004 +Create Date: 2025-09-12 09:56:45.164921+00:00 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "7e92447558e0" +down_revision = "06eafd25d004" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("display_name_userid_uniqueness", "api_keys", type_="unique") + op.create_unique_constraint( + "display_name_userid_product_name_uniqueness", + "api_keys", + ["display_name", "user_id", "product_name"], + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint( + "display_name_userid_product_name_uniqueness", "api_keys", type_="unique" + ) + op.create_unique_constraint( + "display_name_userid_uniqueness", "api_keys", ["display_name", "user_id"] + ) + # ### end Alembic commands ### diff --git a/packages/postgres-database/src/simcore_postgres_database/models/api_keys.py b/packages/postgres-database/src/simcore_postgres_database/models/api_keys.py index 2c3f12eca3ab..02a2fc58bbc2 100644 --- a/packages/postgres-database/src/simcore_postgres_database/models/api_keys.py +++ b/packages/postgres-database/src/simcore_postgres_database/models/api_keys.py @@ -75,7 +75,10 @@ "If set to NULL then the key does not expire.", ), sa.UniqueConstraint( - "display_name", "user_id", name="display_name_userid_uniqueness" + "display_name", + "user_id", + "product_name", + name="display_name_userid_product_name_uniqueness", ), ) diff --git a/services/web/server/src/simcore_service_webserver/api_keys/_repository.py b/services/web/server/src/simcore_service_webserver/api_keys/_repository.py index d765d0a388b0..e787ca04f878 100644 --- a/services/web/server/src/simcore_service_webserver/api_keys/_repository.py +++ b/services/web/server/src/simcore_service_webserver/api_keys/_repository.py @@ -47,7 +47,7 @@ async def create_api_key( expires_at=(sa.func.now() + expiration) if expiration else None, ) .on_conflict_do_update( - index_elements=["user_id", "display_name"], + index_elements=["user_id", "display_name", "product_name"], set_={ "api_key": api_key, "api_secret": _hash_secret(api_secret), diff --git a/services/web/server/tests/unit/with_dbs/01/test_api_keys.py b/services/web/server/tests/unit/with_dbs/01/test_api_keys.py index 902942840f54..0f03fad59fb4 100644 --- a/services/web/server/tests/unit/with_dbs/01/test_api_keys.py +++ b/services/web/server/tests/unit/with_dbs/01/test_api_keys.py @@ -4,10 +4,9 @@ # pylint: disable=too-many-arguments import asyncio -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Awaitable, Callable from datetime import timedelta from http import HTTPStatus -from http.client import HTTPException import pytest import tenacity @@ -39,68 +38,50 @@ @pytest.fixture -async def fake_user_api_keys( +async def fake_api_key_factory( client: TestClient, logged_user: UserInfoDict, osparc_product_name: ProductName, faker: Faker, -) -> AsyncIterable[list[ApiKey]]: +) -> AsyncIterable[Callable[..., Awaitable[ApiKey]]]: assert client.app - api_keys: list[ApiKey] = [ - await _repository.create_api_key( + created_keys: list[tuple[ApiKey, ProductName]] = [] + + async def _create( + *, + product_name: ProductName | None = None, + display_name: str | None = None, + expiration=None, + api_key: str | None = None, + api_secret: str | None = None, + ) -> ApiKey: + final_product_name = product_name or osparc_product_name + final_display_name = display_name or faker.pystr() + final_api_key = api_key or faker.pystr() + final_api_secret = api_secret or faker.pystr() + + created_key = await _repository.create_api_key( client.app, user_id=logged_user["id"], - product_name=osparc_product_name, - display_name=faker.pystr(), - expiration=None, - api_key=faker.pystr(), - api_secret=faker.pystr(), + product_name=final_product_name, + display_name=final_display_name, + expiration=expiration, + api_key=final_api_key, + api_secret=final_api_secret, ) - for _ in range(5) - ] - - yield api_keys - - for api_key in api_keys: - await _repository.delete_api_key( - client.app, - api_key_id=api_key.id, - user_id=logged_user["id"], - product_name=osparc_product_name, - ) - - -@pytest.fixture -async def fake_auto_api_keys( - client: TestClient, - logged_user: UserInfoDict, - osparc_product_name: ProductName, - faker: Faker, -) -> AsyncIterable[list[ApiKey]]: - assert client.app - api_keys: list[ApiKey] = [ - await _repository.create_api_key( - client.app, - user_id=logged_user["id"], - product_name=osparc_product_name, - display_name=API_KEY_AUTOGENERATED_DISPLAY_NAME_PREFIX + faker.pystr(), - expiration=None, - api_key=API_KEY_AUTOGENERATED_KEY_PREFIX + faker.pystr(), - api_secret=faker.pystr(), - ) - for _ in range(5) - ] + created_keys.append((created_key, final_product_name)) + return created_key - yield api_keys + yield _create - for api_key in api_keys: + for api_key, product_name in created_keys: await _repository.delete_api_key( client.app, api_key_id=api_key.id, user_id=logged_user["id"], - product_name=osparc_product_name, + product_name=product_name, ) @@ -123,16 +104,18 @@ def _get_user_access_parametrizations(expected_authed_status_code): async def test_list_api_keys( disabled_setup_garbage_collector: MockType, client: TestClient, - fake_user_api_keys: list[ApiKey], + fake_api_key_factory: Callable[..., Awaitable[ApiKey]], logged_user: UserInfoDict, user_role: UserRole, expected: HTTPStatus, ): + fake_api_keys = [await fake_api_key_factory() for _ in range(10)] + resp = await client.get("/v0/auth/api-keys") data, errors = await assert_status(resp, expected) if not errors: - assert len(data) == len(fake_user_api_keys) + assert len(data) == len(fake_api_keys) @pytest.mark.parametrize( @@ -142,11 +125,20 @@ async def test_list_api_keys( async def test_list_auto_api_keys( disabled_setup_garbage_collector: MockType, client: TestClient, - fake_auto_api_keys: list[ApiKey], + fake_api_key_factory: Callable[..., Awaitable[ApiKey]], logged_user: UserInfoDict, user_role: UserRole, expected: HTTPStatus, + faker: Faker, ): + fake_auto_api_keys = [ + await fake_api_key_factory( + api_key=API_KEY_AUTOGENERATED_KEY_PREFIX + faker.pystr(), + display_name=API_KEY_AUTOGENERATED_DISPLAY_NAME_PREFIX + faker.pystr(), + ) + for _ in range(10) + ] + resp = await client.get( "/v0/auth/api-keys", params={"includeAutogenerated": "true"} ) @@ -203,19 +195,44 @@ async def test_create_api_key( async def test_delete_api_keys( disabled_setup_garbage_collector: MockType, client: TestClient, - fake_user_api_keys: list[ApiKey], + fake_api_key_factory: Callable[..., Awaitable[ApiKey]], logged_user: UserInfoDict, user_role: UserRole, expected: HTTPStatus, ): + fake_api_keys = [await fake_api_key_factory() for _ in range(10)] + resp = await client.delete("/v0/auth/api-keys/0") await assert_status(resp, expected) - for api_key in fake_user_api_keys: + for api_key in fake_api_keys: resp = await client.delete(f"/v0/auth/api-keys/{api_key.id}") await assert_status(resp, expected) +@pytest.mark.parametrize( + "user_role,expected", + _get_user_access_parametrizations(status.HTTP_200_OK), +) +async def test_create_api_keys_same_display_name_different_products( + disabled_setup_garbage_collector: MockType, + client: TestClient, + fake_api_key_factory: Callable[..., Awaitable[ApiKey]], + logged_user: UserInfoDict, + app_products_names: list[str], + user_role: UserRole, + expected: HTTPStatus, +): + display_name = "foo" + + created_keys = [ + await fake_api_key_factory(display_name=display_name, product_name=product_name) + for product_name in app_products_names + ] + + assert len(created_keys) == len(app_products_names) + + EXPIRATION_WAIT_FACTOR = 1.2 @@ -285,7 +302,7 @@ async def test_get_not_existing_api_key( client: TestClient, logged_user: UserInfoDict, user_role: UserRole, - expected: HTTPException, + expected: HTTPStatus, ): resp = await client.get("/v0/auth/api-keys/42") data, errors = await assert_status(resp, expected)