11import logging
2- from typing import Any , Literal , TypeAlias , cast
2+ from typing import Annotated , Any , Literal , TypeAlias , cast
33
44import sqlalchemy as sa
5+ from annotated_types import doc
56from common_library .exclude import Unset , is_unset
67from common_library .users_enums import AccountRequestStatus
78from models_library .products import ProductName
@@ -437,11 +438,21 @@ async def list_merged_pre_and_registered_users(
437438 connection : AsyncConnection | None = None ,
438439 * ,
439440 product_name : ProductName ,
440- filter_any_account_request_status : list [AccountRequestStatus ] | None = None ,
441+ filter_any_account_request_status : Annotated [
442+ list [AccountRequestStatus ] | None ,
443+ doc (
444+ "If provided, only returns users with account request status in this list (only pre-registered users with any of these statuses will be included)"
445+ ),
446+ ] = None ,
441447 filter_include_deleted : bool = False ,
442448 pagination_limit : int = 50 ,
443449 pagination_offset : int = 0 ,
444- order_by : list [tuple [OrderKeys , OrderDirs ]] | None = None ,
450+ order_by : Annotated [
451+ list [tuple [OrderKeys , OrderDirs ]] | None ,
452+ doc (
453+ 'Valid fields: "email", "current_status_created". Default: [("email", "asc"), ("is_pre_registered", "desc"), ("current_status_created", "desc")]'
454+ ),
455+ ] = None ,
445456) -> tuple [list [dict [str , Any ]], int ]:
446457 """Retrieves and merges users from both users and pre-registration tables.
447458
@@ -450,18 +461,6 @@ async def list_merged_pre_and_registered_users(
450461 2. Users who are pre-registered (in users_pre_registration_details table)
451462 3. Users who are both registered and pre-registered
452463
453- Args:
454- engine: Database engine
455- connection: Optional existing connection
456- product_name: Product name to filter by
457- filter_any_account_request_status: If provided, only returns users with account request status in this list
458- (only pre-registered users with any of these statuses will be included)
459- filter_include_deleted: Whether to include deleted users
460- pagination_limit: Maximum number of results to return
461- pagination_offset: Number of results to skip (for pagination)
462- order_by: List of (field, direction) tuples. Valid fields: "email", "current_status_created"
463- Default: [("email", "asc"), ("is_pre_registered", "desc"), ("current_status_created", "desc")]
464-
465464 Returns:
466465 Tuple of (list of merged user data, total count)
467466 """
@@ -579,22 +578,41 @@ async def list_merged_pre_and_registered_users(
579578 else :
580579 merged_query = pre_registered_users_query .union_all (registered_users_query )
581580
582- # Add distinct on email to eliminate duplicates
581+ # Add distinct on email to eliminate duplicates using ROW_NUMBER()
583582 merged_query_subq = merged_query .subquery ()
584583
584+ # Use ROW_NUMBER() to prioritize records per email
585+ # This allows us to order by any field without DISTINCT ON constraints
586+ ranked_query = sa .select (
587+ merged_query_subq ,
588+ sa .func .row_number ()
589+ .over (
590+ partition_by = merged_query_subq .c .email ,
591+ order_by = [
592+ merged_query_subq .c .is_pre_registered .desc (), # Prioritize pre-registered
593+ merged_query_subq .c .current_status_created .desc (), # Then by most recent
594+ ],
595+ )
596+ .label ("rn" ),
597+ ).subquery ()
598+
599+ # Filter to get only the first record per email (rn = 1)
600+ filtered_query = sa .select (
601+ * [col for col in ranked_query .c if col .name != "rn" ]
602+ ).where (ranked_query .c .rn == 1 )
603+
585604 # Build ordering clauses using the extracted function
586- order_by_clauses = _build_ordering_clauses (merged_query_subq , order_by )
605+ order_by_clauses = _build_ordering_clauses_for_filtered_query (
606+ filtered_query , order_by
607+ )
587608
588- distinct_query = (
589- sa .select (merged_query_subq )
590- .select_from (merged_query_subq )
591- .distinct (merged_query_subq .c .email )
592- .order_by (* order_by_clauses )
609+ final_query = (
610+ filtered_query .order_by (* order_by_clauses )
593611 .limit (pagination_limit )
594612 .offset (pagination_offset )
595613 )
596614
597- # Count query (for pagination)
615+ # Count query (for pagination) - count distinct emails
598616 count_query = sa .select (sa .func .count ().label ("total" )).select_from (
599617 sa .select (merged_query_subq .c .email )
600618 .select_from (merged_query_subq )
@@ -608,25 +626,17 @@ async def list_merged_pre_and_registered_users(
608626 total_count = count_result .scalar_one ()
609627
610628 # Get user records
611- result = await conn .execute (distinct_query )
629+ result = await conn .execute (final_query )
612630 records = result .mappings ().all ()
613631
614632 return cast (list [dict [str , Any ]], records ), total_count
615633
616634
617- def _build_ordering_clauses (
618- merged_query_subq : sa .sql .Subquery ,
635+ def _build_ordering_clauses_for_filtered_query (
636+ query : sa .sql .Select ,
619637 order_by : list [tuple [OrderKeys , OrderDirs ]] | None = None ,
620638) -> list [sa .sql .ColumnElement ]:
621- """Build ORDER BY clauses for merged user queries.
622-
623- Args:
624- merged_query_subq: The merged query subquery containing all columns
625- order_by: List of (field, direction) tuples for ordering
626-
627- Returns:
628- List of SQLAlchemy ordering clauses
629- """
639+ """Build ORDER BY clauses for filtered query (no DISTINCT ON constraints)."""
630640 _ordering_criteria : list [tuple [str , OrderDirs ]] = []
631641
632642 if order_by is None :
@@ -644,7 +654,8 @@ def _build_ordering_clauses(
644654
645655 order_by_clauses = []
646656 for field , direction in _ordering_criteria :
647- column = merged_query_subq .columns [field ]
657+ # Get column from the query's selected columns
658+ column = next (col for col in query .selected_columns if col .name == field )
648659 if direction == "asc" :
649660 order_by_clauses .append (column .asc ())
650661 else :
0 commit comments