1- from sqlmodel . ext . asyncio . session import AsyncSession
2- from sqlmodel import select , func , SQLModel
3- from typing import Type , TypeVar , Sequence , Optional
1+ from sqlalchemy import Row , Select
2+ from sqlmodel import Session , select , func , SQLModel
3+ from typing import Dict , Type , TypeVar , Sequence , Optional
44from common .core .schemas import PaginationParams , PaginatedResponse
5+ from sqlmodel .sql .expression import SelectOfScalar
6+ from typing import Union , Any
57
68ModelT = TypeVar ('ModelT' , bound = SQLModel )
79
810class Paginator :
9- def __init__ (self , session : AsyncSession ):
11+ def __init__ (self , session : Session ):
1012 self .session = session
11-
13+ def _process_result_row (self , row : Row ) -> Dict [str , Any ]:
14+ result_dict = {}
15+ for item , key in zip (row , row ._fields ):
16+ if isinstance (item , SQLModel ):
17+ result_dict .update (item .dict ())
18+ else :
19+ result_dict [key ] = item
20+
21+ return result_dict
1222 async def paginate (
1323 self ,
14- model : Type [ModelT ],
24+ stmt : Union [ Select , SelectOfScalar , Type [ModelT ] ],
1525 page : int = 1 ,
1626 size : int = 20 ,
1727 order_by : Optional [str ] = None ,
1828 desc : bool = False ,
1929 ** filters
20- ) -> tuple [Sequence [ModelT ], int ]:
30+ ) -> tuple [Sequence [Any ], int ]:
2131 offset = (page - 1 ) * size
22- stmt = select (model )
32+ single_model : bool = False
33+ if isinstance (stmt , type ) and issubclass (stmt , SQLModel ):
34+ stmt = select (stmt )
35+ single_model = True
2336
37+ # 应用过滤条件
2438 for field , value in filters .items ():
2539 if value is not None :
26- stmt = stmt .where (getattr (model , field ) == value )
40+ # 处理关联模型的字段 (如 user.name)
41+ if '.' in field :
42+ related_model , related_field = field .split ('.' )
43+ # 这里需要根据实际关联关系调整
44+ stmt = stmt .where (getattr (getattr (stmt .selected_columns , related_model ), related_field ) == value )
45+ else :
46+ stmt = stmt .where (getattr (stmt .selected_columns , field ) == value )
2747
48+ # 应用排序
2849 if order_by :
29- column = getattr (model , order_by )
50+ if '.' in order_by :
51+ related_model , related_field = order_by .split ('.' )
52+ column = getattr (getattr (stmt .selected_columns , related_model ), related_field )
53+ else :
54+ column = getattr (stmt .selected_columns , order_by )
3055 stmt = stmt .order_by (column .desc () if desc else column .asc ())
3156
32- count_stmt = select (func .count ()).select_from (model )
33- for field , value in filters .items ():
34- if value is not None :
35- count_stmt = count_stmt .where (getattr (model , field ) == value )
36-
57+ # 计算总数
58+ """ count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True)
3759 result = self.session.exec(count_stmt)
38- total = result .first ()
60+ total: int = result.first() """
61+ count_stmt = select (func .count ()).select_from (stmt .subquery ())
62+ total_result = self .session .exec (count_stmt )
63+ total : int = total_result .first ()
3964
65+ # 应用分页
4066 stmt = stmt .offset (offset ).limit (size )
4167
68+ # 执行查询
4269 result = self .session .exec (stmt )
43- items = result .all ()
44-
70+ if not single_model :
71+ items = [self ._process_result_row (row ) for row in result ]
72+ else :
73+ items = result .all ()
4574 return items , total
4675
4776 async def get_paginated_response (
4877 self ,
49- model : Type [ModelT ],
78+ stmt : Union [ Select , SelectOfScalar , Type [ModelT ] ],
5079 pagination : PaginationParams ,
5180 ** filters
52- ) -> PaginatedResponse [ModelT ]:
81+ ) -> PaginatedResponse [Any ]:
5382 items , total = await self .paginate (
54- model = model ,
83+ stmt = stmt ,
5584 page = pagination .page ,
5685 size = pagination .size ,
5786 order_by = pagination .order_by ,
@@ -61,7 +90,7 @@ async def get_paginated_response(
6190
6291 total_pages = (total + pagination .size - 1 ) // pagination .size
6392
64- return PaginatedResponse [ModelT ](
93+ return PaginatedResponse [Any ](
6594 items = items ,
6695 total = total ,
6796 page = pagination .page ,
0 commit comments