Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
from types import UnionType
from typing import Any, Literal, get_args, get_origin
from typing import Annotated, Any, Literal, Union, get_args, get_origin

from pydantic.fields import FieldInfo

NoneType: type = type(None)


def get_type(info: FieldInfo) -> Any:
field_type = info.annotation
if args := get_args(info.annotation):
field_type = next(a for a in args if a is not type(None))
field_type = next(a for a in args if a is not NoneType)
return field_type


def _unwrap_annotation(ann):
"""Peel off Annotated wrappers until reaching the core type."""
while get_origin(ann) is Annotated:
ann = get_args(ann)[0]
return ann


def is_literal(info: FieldInfo) -> bool:
return get_origin(info.annotation) is Literal
ann = _unwrap_annotation(info.annotation)
return get_origin(ann) is Literal


def is_nullable(info: FieldInfo) -> bool:
origin = get_origin(info.annotation) # X | None or Optional[X] will return Union
if origin is UnionType:
return any(x in get_args(info.annotation) for x in (type(None), Any))
return False
"""Checks whether a field allows None as a value."""
ann = _unwrap_annotation(info.annotation)
origin = get_origin(ann) # X | None or Optional[X] will return Union

if origin in (Union, UnionType):
return any(arg is NoneType or arg is Any for arg in get_args(ann))

return ann is NoneType or ann is Any
16 changes: 14 additions & 2 deletions packages/common-library/tests/test_pydantic_fields_extension.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from collections.abc import Callable
from typing import Any, Literal
from typing import Annotated, Any, Literal

import pytest
from common_library.pydantic_fields_extension import get_type, is_literal, is_nullable
from pydantic import BaseModel
from pydantic import BaseModel, PositiveInt


class MyModel(BaseModel):
Expand All @@ -12,6 +12,11 @@ class MyModel(BaseModel):
c: str = "bla"
d: bool | None = None
e: Literal["bla"]
f: Annotated[
PositiveInt | None,
"nullable inside Annotated (PositiveInt = Annotated[int, ...])",
]
g: Annotated[Literal["foo", "bar"], "literal inside Annotated"]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -50,6 +55,8 @@ class MyModel(BaseModel):
),
(is_literal, False, "d"),
(is_literal, True, "e"),
(is_literal, False, "f"),
(is_literal, True, "g"),
(
is_nullable,
False,
Expand All @@ -67,6 +74,11 @@ class MyModel(BaseModel):
),
(is_nullable, True, "d"),
(is_nullable, False, "e"),
(
is_nullable,
True,
"f",
),
],
)
def test_field_fn(fn: Callable[[Any], Any], expected: Any, name: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from ..basic_types import IDStr
from ..emails import LowerCaseEmailStr
from ..groups import AccessRightsDict, Group, GroupID, GroupsByTypeTuple
from ..groups import AccessRightsDict, Group, GroupID, GroupsByTypeTuple, PrimaryGroupID
from ..products import ProductName
from ..rest_base import RequestParameters
from ..users import (
Expand Down Expand Up @@ -381,14 +381,32 @@ class UserAccountGet(OutputSchema):

# user status
registered: bool
status: UserStatus | None
status: UserStatus | None = None
products: Annotated[
list[ProductName] | None,
Field(
description="List of products this users is included or None if fields is unset",
),
] = None

# user (if an account was created)
user_id: Annotated[
UserID | None,
Field(description="Unique identifier of the user if an account was created"),
] = None
user_name: Annotated[
UserNameID | None,
Field(description="Username of the user if an account was created"),
] = None
user_primary_group_id: Annotated[
PrimaryGroupID | None,
Field(
description="Primary group ID of the user if an account was created",
alias="groupId",
# SEE https://github.com/ITISFoundation/osparc-simcore/pull/8358#issuecomment-3279491740
),
] = None

@field_validator("status")
@classmethod
def _consistency_check(cls, v, info: ValidationInfo):
Expand Down
4 changes: 3 additions & 1 deletion packages/models-library/src/models_library/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
EVERYONE_GROUP_ID: Final[int] = 1

GroupID: TypeAlias = PositiveInt
PrimaryGroupID: TypeAlias = Annotated[GroupID, Field(gt=EVERYONE_GROUP_ID)]
StandardGroupID: TypeAlias = Annotated[GroupID, Field(gt=EVERYONE_GROUP_ID)]

__all__: tuple[str, ...] = ("GroupType",)


class Group(BaseModel):
gid: PositiveInt
gid: GroupID
name: str
description: str
group_type: Annotated[GroupType, Field(alias="type")]
Expand Down
2 changes: 1 addition & 1 deletion services/web/server/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.78.0
0.79.0
2 changes: 1 addition & 1 deletion services/web/server/setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.78.0
current_version = 0.79.0
commit = True
message = services/webserver api version: {current_version} → {new_version}
tag = False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
openapi: 3.1.0
info:
title: simcore-service-webserver
description: Main service with an interface (http-API & websockets) to the web front-end
version: 0.78.0
version: 0.79.0
servers:
- url: ''
description: webserver
Expand Down Expand Up @@ -18429,6 +18429,30 @@
title: Products
description: List of products this users is included or None if fields is
unset
userId:
anyOf:
- type: integer
exclusiveMinimum: true
minimum: 0
- type: 'null'
title: Userid
description: Unique identifier of the user if an account was created
userName:
anyOf:
- type: string
maxLength: 100
minLength: 1
- type: 'null'
title: Username
description: Username of the user if an account was created
groupId:
anyOf:
- type: integer
exclusiveMinimum: true
minimum: 1
- type: 'null'
title: Groupid
description: Primary group ID of the user if an account was created
type: object
required:
- firstName
Expand All @@ -18445,7 +18469,6 @@
- preRegistrationCreated
- accountRequestStatus
- registered
- status
title: UserAccountGet
UserAccountReject:
properties:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,19 +381,25 @@ async def search_merged_pre_and_registered_users(
users_pre_registration_details.c.state,
users_pre_registration_details.c.postal_code,
users_pre_registration_details.c.country,
users_pre_registration_details.c.user_id,
users_pre_registration_details.c.user_id.label("pre_reg_user_id"),
users_pre_registration_details.c.extras,
users_pre_registration_details.c.account_request_status,
users_pre_registration_details.c.account_request_reviewed_by,
users_pre_registration_details.c.account_request_reviewed_at,
users.c.status,
invited_by,
account_request_reviewed_by_username, # account_request_reviewed_by converted to username
users_pre_registration_details.c.created,
# NOTE: some users have no pre-registration details (e.g. s4l-lite)
users.c.id.label("user_id"), # real user_id from users table
users.c.name.label("user_name"),
users.c.primary_gid.label("user_primary_group_id"),
users.c.status,
)

left_outer_join = _build_left_outer_join_query(
filter_by_email_like, product_name, columns
filter_by_email_like,
product_name,
columns,
)
right_outer_join = _build_right_outer_join_query(
filter_by_email_like,
Expand Down Expand Up @@ -494,6 +500,7 @@ async def list_merged_pre_and_registered_users(
users_pre_registration_details.c.account_request_reviewed_at,
users.c.id.label("user_id"),
users.c.name.label("user_name"),
users.c.primary_gid.label("user_primary_group_id"),
users.c.status,
# Use created_by directly instead of a subquery
users_pre_registration_details.c.created_by.label("created_by"),
Expand Down Expand Up @@ -530,6 +537,7 @@ async def list_merged_pre_and_registered_users(
sa.literal(None).label("account_request_reviewed_at"),
users.c.id.label("user_id"),
users.c.name.label("user_name"),
users.c.primary_gid.label("user_primary_group_id"),
users.c.status,
# Match the created_by field from the pre_reg query
sa.literal(None).label("created_by"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ async def _list_products_or_none(user_id):
# NOTE: old users will not have extra details
registered=r.user_id is not None if r.pre_email else r.status is not None,
status=r.status,
# user
user_id=r.user_id,
user_name=r.user_name,
user_primary_group_id=r.user_primary_group_id,
)
for r in rows
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ async def test_search_and_pre_registration(
):
assert client.app

# NOTE: listing of user accounts drops nullable fields to avoid lengthy responses (even if they have no defaults)
# therefore they are reconstructed here from http response payloads
nullable_fields = {
name: None
for name, field in UserAccountGet.model_fields.items()
if is_nullable(field)
}

# ONLY in `users` and NOT `users_pre_registration_details`
resp = await client.get(
"/v0/admin/user-accounts:search", params={"email": logged_user["email"]}
Expand All @@ -240,12 +248,6 @@ async def test_search_and_pre_registration(
found, _ = await assert_status(resp, status.HTTP_200_OK)
assert len(found) == 1

nullable_fields = {
name: None
for name, field in UserAccountGet.model_fields.items()
if is_nullable(field)
}

got = UserAccountGet.model_validate({**nullable_fields, **found[0]})
expected = {
"first_name": logged_user.get("first_name"),
Expand All @@ -261,6 +263,9 @@ async def test_search_and_pre_registration(
"extras": {},
"registered": True,
"status": UserStatus.ACTIVE,
"user_id": logged_user["id"],
"user_name": logged_user["name"],
"user_primary_group_id": logged_user.get("primary_gid"),
}
assert got.model_dump(include=set(expected)) == expected

Expand All @@ -278,8 +283,8 @@ async def test_search_and_pre_registration(
)
found, _ = await assert_status(resp, status.HTTP_200_OK)
assert len(found) == 1
got = UserAccountGet(**found[0], state=None, status=None)

got = UserAccountGet.model_validate({**nullable_fields, **found[0]})
assert got.model_dump(include={"registered", "status"}) == {
"registered": False,
"status": None,
Expand All @@ -302,7 +307,8 @@ async def test_search_and_pre_registration(
)
found, _ = await assert_status(resp, status.HTTP_200_OK)
assert len(found) == 1
got = UserAccountGet(**found[0], state=None)

got = UserAccountGet.model_validate({**nullable_fields, **found[0]})
assert got.model_dump(include={"registered", "status"}) == {
"registered": True,
"status": new_user["status"],
Expand Down
Loading