Skip to content

Commit ac7c462

Browse files
committed
Enhance list_merged_pre_and_registered_users with detailed parameter documentation and validation for order_by field
1 parent e44b2a7 commit ac7c462

File tree

2 files changed

+89
-36
lines changed

2 files changed

+89
-36
lines changed

services/web/server/src/simcore_service_webserver/users/_accounts_repository.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
2-
from typing import Any, Literal, TypeAlias, cast
2+
from typing import Annotated, Any, Literal, TypeAlias, cast
33

44
import sqlalchemy as sa
5+
from annotated_types import doc
56
from common_library.exclude import Unset, is_unset
67
from common_library.users_enums import AccountRequestStatus
78
from models_library.products import ProductName
@@ -437,11 +438,21 @@ async def list_merged_pre_and_registered_users(
437438
connection: AsyncConnection | None = None,
438439
*,
439440
product_name: ProductName,
440-
filter_any_account_request_status: list[AccountRequestStatus] | None = None,
441+
filter_any_account_request_status: Annotated[
442+
list[AccountRequestStatus] | None,
443+
doc(
444+
"If provided, only returns users with account request status in this list (only pre-registered users with any of these statuses will be included)"
445+
),
446+
] = None,
441447
filter_include_deleted: bool = False,
442448
pagination_limit: int = 50,
443449
pagination_offset: int = 0,
444-
order_by: list[tuple[OrderKeys, OrderDirs]] | None = None,
450+
order_by: Annotated[
451+
list[tuple[OrderKeys, OrderDirs]] | None,
452+
doc(
453+
'Valid fields: "email", "current_status_created". Default: [("email", "asc"), ("is_pre_registered", "desc"), ("current_status_created", "desc")]'
454+
),
455+
] = None,
445456
) -> tuple[list[dict[str, Any]], int]:
446457
"""Retrieves and merges users from both users and pre-registration tables.
447458
@@ -450,18 +461,6 @@ async def list_merged_pre_and_registered_users(
450461
2. Users who are pre-registered (in users_pre_registration_details table)
451462
3. Users who are both registered and pre-registered
452463
453-
Args:
454-
engine: Database engine
455-
connection: Optional existing connection
456-
product_name: Product name to filter by
457-
filter_any_account_request_status: If provided, only returns users with account request status in this list
458-
(only pre-registered users with any of these statuses will be included)
459-
filter_include_deleted: Whether to include deleted users
460-
pagination_limit: Maximum number of results to return
461-
pagination_offset: Number of results to skip (for pagination)
462-
order_by: List of (field, direction) tuples. Valid fields: "email", "current_status_created"
463-
Default: [("email", "asc"), ("is_pre_registered", "desc"), ("current_status_created", "desc")]
464-
465464
Returns:
466465
Tuple of (list of merged user data, total count)
467466
"""
@@ -579,22 +578,41 @@ async def list_merged_pre_and_registered_users(
579578
else:
580579
merged_query = pre_registered_users_query.union_all(registered_users_query)
581580

582-
# Add distinct on email to eliminate duplicates
581+
# Add distinct on email to eliminate duplicates using ROW_NUMBER()
583582
merged_query_subq = merged_query.subquery()
584583

584+
# Use ROW_NUMBER() to prioritize records per email
585+
# This allows us to order by any field without DISTINCT ON constraints
586+
ranked_query = sa.select(
587+
merged_query_subq,
588+
sa.func.row_number()
589+
.over(
590+
partition_by=merged_query_subq.c.email,
591+
order_by=[
592+
merged_query_subq.c.is_pre_registered.desc(), # Prioritize pre-registered
593+
merged_query_subq.c.current_status_created.desc(), # Then by most recent
594+
],
595+
)
596+
.label("rn"),
597+
).subquery()
598+
599+
# Filter to get only the first record per email (rn = 1)
600+
filtered_query = sa.select(
601+
*[col for col in ranked_query.c if col.name != "rn"]
602+
).where(ranked_query.c.rn == 1)
603+
585604
# Build ordering clauses using the extracted function
586-
order_by_clauses = _build_ordering_clauses(merged_query_subq, order_by)
605+
order_by_clauses = _build_ordering_clauses_for_filtered_query(
606+
filtered_query, order_by
607+
)
587608

588-
distinct_query = (
589-
sa.select(merged_query_subq)
590-
.select_from(merged_query_subq)
591-
.distinct(merged_query_subq.c.email)
592-
.order_by(*order_by_clauses)
609+
final_query = (
610+
filtered_query.order_by(*order_by_clauses)
593611
.limit(pagination_limit)
594612
.offset(pagination_offset)
595613
)
596614

597-
# Count query (for pagination)
615+
# Count query (for pagination) - count distinct emails
598616
count_query = sa.select(sa.func.count().label("total")).select_from(
599617
sa.select(merged_query_subq.c.email)
600618
.select_from(merged_query_subq)
@@ -608,25 +626,17 @@ async def list_merged_pre_and_registered_users(
608626
total_count = count_result.scalar_one()
609627

610628
# Get user records
611-
result = await conn.execute(distinct_query)
629+
result = await conn.execute(final_query)
612630
records = result.mappings().all()
613631

614632
return cast(list[dict[str, Any]], records), total_count
615633

616634

617-
def _build_ordering_clauses(
618-
merged_query_subq: sa.sql.Subquery,
635+
def _build_ordering_clauses_for_filtered_query(
636+
query: sa.sql.Select,
619637
order_by: list[tuple[OrderKeys, OrderDirs]] | None = None,
620638
) -> list[sa.sql.ColumnElement]:
621-
"""Build ORDER BY clauses for merged user queries.
622-
623-
Args:
624-
merged_query_subq: The merged query subquery containing all columns
625-
order_by: List of (field, direction) tuples for ordering
626-
627-
Returns:
628-
List of SQLAlchemy ordering clauses
629-
"""
639+
"""Build ORDER BY clauses for filtered query (no DISTINCT ON constraints)."""
630640
_ordering_criteria: list[tuple[str, OrderDirs]] = []
631641

632642
if order_by is None:
@@ -644,7 +654,8 @@ def _build_ordering_clauses(
644654

645655
order_by_clauses = []
646656
for field, direction in _ordering_criteria:
647-
column = merged_query_subq.columns[field]
657+
# Get column from the query's selected columns
658+
column = next(col for col in query.selected_columns if col.name == field)
648659
if direction == "asc":
649660
order_by_clauses.append(column.asc())
650661
else:

services/web/server/tests/unit/with_dbs/03/users/test_users_accounts_repository.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from common_library.users_enums import AccountRequestStatus
1414
from models_library.products import ProductName
1515
from models_library.users import UserID
16+
from pydantic import ValidationError
1617
from simcore_postgres_database.models.users_details import (
1718
users_pre_registration_details,
1819
)
@@ -1122,3 +1123,44 @@ async def test_list_merged_users_custom_ordering_with_current_status_created(
11221123
assert (
11231124
current_status_dates == sorted_dates
11241125
), "Users should be ordered by current_status_created desc"
1126+
1127+
1128+
@pytest.mark.parametrize(
1129+
"invalid_order_by,expected_error_pattern",
1130+
[
1131+
# Invalid field name
1132+
(
1133+
[("invalid_field", "asc")],
1134+
r"Input should be 'email' or 'current_status_created'",
1135+
),
1136+
# Invalid direction
1137+
([("email", "invalid_direction")], r"Input should be 'asc' or 'desc'"),
1138+
# Multiple invalid values
1139+
(
1140+
[("invalid_field", "invalid_direction")],
1141+
r"Input should be 'email' or 'current_status_created'",
1142+
),
1143+
# Mixed valid and invalid
1144+
(
1145+
[("email", "asc"), ("invalid_field", "desc")],
1146+
r"Input should be 'email' or 'current_status_created'",
1147+
),
1148+
],
1149+
)
1150+
async def test_list_merged_users_invalid_order_by_validation(
1151+
app: web.Application,
1152+
product_name: ProductName,
1153+
invalid_order_by: list[tuple[str, str]],
1154+
expected_error_pattern: str,
1155+
):
1156+
"""Test that invalid order_by parameters are rejected by Pydantic validation."""
1157+
1158+
asyncpg_engine = get_asyncpg_engine(app)
1159+
1160+
# Act & Assert - Should raise ValidationError
1161+
with pytest.raises(ValidationError, match=expected_error_pattern):
1162+
await _accounts_repository.list_merged_pre_and_registered_users(
1163+
asyncpg_engine,
1164+
product_name=product_name,
1165+
order_by=invalid_order_by,
1166+
)

0 commit comments

Comments
 (0)