Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Update api-keys uniqueness constraint

Revision ID: 7e92447558e0
Revises: 06eafd25d004
Create Date: 2025-09-12 09:56:45.164921+00:00

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "7e92447558e0"
down_revision = "06eafd25d004"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("display_name_userid_uniqueness", "api_keys", type_="unique")
op.create_unique_constraint(
"display_name_userid_product_name_uniqueness",
"api_keys",
["display_name", "user_id", "product_name"],
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(
"display_name_userid_product_name_uniqueness", "api_keys", type_="unique"
)
op.create_unique_constraint(
"display_name_userid_uniqueness", "api_keys", ["display_name", "user_id"]
)
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@
"If set to NULL then the key does not expire.",
),
sa.UniqueConstraint(
"display_name", "user_id", name="display_name_userid_uniqueness"
"display_name",
"user_id",
"product_name",
name="display_name_userid_product_name_uniqueness",
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def create_api_key(
expires_at=(sa.func.now() + expiration) if expiration else None,
)
.on_conflict_do_update(
index_elements=["user_id", "display_name"],
index_elements=["user_id", "display_name", "product_name"],
set_={
"api_key": api_key,
"api_secret": _hash_secret(api_secret),
Expand Down
125 changes: 71 additions & 54 deletions services/web/server/tests/unit/with_dbs/01/test_api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
# pylint: disable=too-many-arguments

import asyncio
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, Awaitable, Callable
from datetime import timedelta
from http import HTTPStatus
from http.client import HTTPException

import pytest
import tenacity
Expand Down Expand Up @@ -39,68 +38,50 @@


@pytest.fixture
async def fake_user_api_keys(
async def fake_api_key_factory(
client: TestClient,
logged_user: UserInfoDict,
osparc_product_name: ProductName,
faker: Faker,
) -> AsyncIterable[list[ApiKey]]:
) -> AsyncIterable[Callable[..., Awaitable[ApiKey]]]:
assert client.app

api_keys: list[ApiKey] = [
await _repository.create_api_key(
created_keys: list[tuple[ApiKey, ProductName]] = []

async def _create(
*,
product_name: ProductName | None = None,
display_name: str | None = None,
expiration=None,
api_key: str | None = None,
api_secret: str | None = None,
) -> ApiKey:
final_product_name = product_name or osparc_product_name
final_display_name = display_name or faker.pystr()
final_api_key = api_key or faker.pystr()
final_api_secret = api_secret or faker.pystr()

created_key = await _repository.create_api_key(
client.app,
user_id=logged_user["id"],
product_name=osparc_product_name,
display_name=faker.pystr(),
expiration=None,
api_key=faker.pystr(),
api_secret=faker.pystr(),
product_name=final_product_name,
display_name=final_display_name,
expiration=expiration,
api_key=final_api_key,
api_secret=final_api_secret,
)
for _ in range(5)
]

yield api_keys

for api_key in api_keys:
await _repository.delete_api_key(
client.app,
api_key_id=api_key.id,
user_id=logged_user["id"],
product_name=osparc_product_name,
)


@pytest.fixture
async def fake_auto_api_keys(
client: TestClient,
logged_user: UserInfoDict,
osparc_product_name: ProductName,
faker: Faker,
) -> AsyncIterable[list[ApiKey]]:
assert client.app

api_keys: list[ApiKey] = [
await _repository.create_api_key(
client.app,
user_id=logged_user["id"],
product_name=osparc_product_name,
display_name=API_KEY_AUTOGENERATED_DISPLAY_NAME_PREFIX + faker.pystr(),
expiration=None,
api_key=API_KEY_AUTOGENERATED_KEY_PREFIX + faker.pystr(),
api_secret=faker.pystr(),
)
for _ in range(5)
]
created_keys.append((created_key, final_product_name))
return created_key

yield api_keys
yield _create

for api_key in api_keys:
for api_key, product_name in created_keys:
await _repository.delete_api_key(
client.app,
api_key_id=api_key.id,
user_id=logged_user["id"],
product_name=osparc_product_name,
product_name=product_name,
)


Expand All @@ -123,16 +104,18 @@ def _get_user_access_parametrizations(expected_authed_status_code):
async def test_list_api_keys(
disabled_setup_garbage_collector: MockType,
client: TestClient,
fake_user_api_keys: list[ApiKey],
fake_api_key_factory: Callable[..., Awaitable[ApiKey]],
logged_user: UserInfoDict,
user_role: UserRole,
expected: HTTPStatus,
):
fake_api_keys = [await fake_api_key_factory() for _ in range(10)]

resp = await client.get("/v0/auth/api-keys")
data, errors = await assert_status(resp, expected)

if not errors:
assert len(data) == len(fake_user_api_keys)
assert len(data) == len(fake_api_keys)


@pytest.mark.parametrize(
Expand All @@ -142,11 +125,20 @@ async def test_list_api_keys(
async def test_list_auto_api_keys(
disabled_setup_garbage_collector: MockType,
client: TestClient,
fake_auto_api_keys: list[ApiKey],
fake_api_key_factory: Callable[..., Awaitable[ApiKey]],
logged_user: UserInfoDict,
user_role: UserRole,
expected: HTTPStatus,
faker: Faker,
):
fake_auto_api_keys = [
await fake_api_key_factory(
api_key=API_KEY_AUTOGENERATED_KEY_PREFIX + faker.pystr(),
display_name=API_KEY_AUTOGENERATED_DISPLAY_NAME_PREFIX + faker.pystr(),
)
for _ in range(10)
]

resp = await client.get(
"/v0/auth/api-keys", params={"includeAutogenerated": "true"}
)
Expand Down Expand Up @@ -203,19 +195,44 @@ async def test_create_api_key(
async def test_delete_api_keys(
disabled_setup_garbage_collector: MockType,
client: TestClient,
fake_user_api_keys: list[ApiKey],
fake_api_key_factory: Callable[..., Awaitable[ApiKey]],
logged_user: UserInfoDict,
user_role: UserRole,
expected: HTTPStatus,
):
fake_api_keys = [await fake_api_key_factory() for _ in range(10)]

resp = await client.delete("/v0/auth/api-keys/0")
await assert_status(resp, expected)

for api_key in fake_user_api_keys:
for api_key in fake_api_keys:
resp = await client.delete(f"/v0/auth/api-keys/{api_key.id}")
await assert_status(resp, expected)


@pytest.mark.parametrize(
"user_role,expected",
_get_user_access_parametrizations(status.HTTP_200_OK),
)
async def test_create_api_keys_same_display_name_different_products(
disabled_setup_garbage_collector: MockType,
client: TestClient,
fake_api_key_factory: Callable[..., Awaitable[ApiKey]],
logged_user: UserInfoDict,
app_products_names: list[str],
user_role: UserRole,
expected: HTTPStatus,
):
display_name = "foo"

created_keys = [
await fake_api_key_factory(display_name=display_name, product_name=product_name)
for product_name in app_products_names
]

assert len(created_keys) == len(app_products_names)


EXPIRATION_WAIT_FACTOR = 1.2


Expand Down Expand Up @@ -285,7 +302,7 @@ async def test_get_not_existing_api_key(
client: TestClient,
logged_user: UserInfoDict,
user_role: UserRole,
expected: HTTPException,
expected: HTTPStatus,
):
resp = await client.get("/v0/auth/api-keys/42")
data, errors = await assert_status(resp, expected)
Expand Down
Loading