Skip to content

Commit afc5181

Browse files
authored
Merge pull request #22 from febus982/repository_pagination
Pagination and protocols
2 parents 878482f + 6dad316 commit afc5181

File tree

13 files changed

+666
-123
lines changed

13 files changed

+666
-123
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ parallel = true
88
exclude_lines =
99
pragma: no cover
1010
pass
11+
\.\.\.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ The classes provide some common use methods:
198198
* `save_many`: Persist multiple models in a single transaction
199199
* `delete`: Delete a model
200200
* `find`: Search for a list of models (basically an adapter for SELECT queries)
201+
* `paginated_find`: Search for a list of models, with pagination support
201202

202203
### Session lifecycle in repositories
203204

sqlalchemy_bind_manager/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
SQLAlchemyRepository,
88
SQLAlchemyAsyncRepository,
99
SortDirection,
10+
PaginatedResult,
1011
)
1112
from ._unit_of_work import (
1213
UnitOfWork,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .sync import SQLAlchemyRepository
22
from .async_ import SQLAlchemyAsyncRepository
3-
from .common import SortDirection
3+
from .common import SortDirection, PaginatedResult

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
from abc import ABC
22
from contextlib import asynccontextmanager
3-
from typing import Union, Generic, Tuple, Iterable, List, AsyncIterator, Any, Mapping
3+
from typing import (
4+
Union,
5+
Generic,
6+
Tuple,
7+
Iterable,
8+
List,
9+
AsyncIterator,
10+
Any,
11+
Mapping,
12+
Type,
13+
)
414

5-
from sqlalchemy import select
615
from sqlalchemy.ext.asyncio import AsyncSession
716

817
from .._bind_manager import SQLAlchemyAsyncBind
918
from .._transaction_handler import AsyncSessionHandler
1019
from ..exceptions import ModelNotFound, InvalidConfig
11-
from .common import MODEL, PRIMARY_KEY, SortDirection, BaseRepository
20+
from .common import MODEL, PRIMARY_KEY, SortDirection, BaseRepository, PaginatedResult
1221

1322

1423
class SQLAlchemyAsyncRepository(Generic[MODEL], BaseRepository[MODEL], ABC):
@@ -19,14 +28,17 @@ def __init__(
1928
self,
2029
bind: Union[SQLAlchemyAsyncBind, None] = None,
2130
session: Union[AsyncSession, None] = None,
31+
model_class: Union[Type[MODEL], None] = None,
2232
) -> None:
2333
"""
2434
:param bind: A configured instance of SQLAlchemyAsyncBind
25-
:type bind: SQLAlchemyAsyncBind
35+
:type bind: Union[SQLAlchemyAsyncBind, None]
2636
:param session: An externally managed session
27-
:type session: AsyncSession
37+
:type session: Union[AsyncSession, None]
38+
:param model_class: A mapped SQLAlchemy model
39+
:type model_class: Union[Type[MODEL], None]
2840
"""
29-
super().__init__()
41+
super().__init__(model_class=model_class)
3042
if not (bool(bind) ^ bool(session)):
3143
raise InvalidConfig("Either `bind` or `session` have to be used, not both")
3244
self._external_session = session
@@ -105,16 +117,55 @@ async def find(
105117
106118
:param order_by:
107119
:param search_params: A dictionary containing equality filters
120+
:param limit: Number of models to retrieve
121+
:type limit: int
122+
:param offset: Number of models to skip
123+
:type offset: int
108124
:return: A collection of models
109125
:rtype: List
110126
"""
111-
stmt = select(self._model)
112-
if search_params:
113-
stmt = self._filter_select(stmt, search_params)
114-
115-
if order_by is not None:
116-
stmt = self._filter_order_by(stmt, order_by)
127+
stmt = self._find_query(search_params, order_by)
117128

118129
async with self._get_session() as session:
119130
result = await session.execute(stmt)
120131
return [x for x in result.scalars()]
132+
133+
async def paginated_find(
134+
self,
135+
per_page: int,
136+
page: int,
137+
search_params: Union[None, Mapping[str, Any]] = None,
138+
order_by: Union[None, Iterable[Union[str, Tuple[str, SortDirection]]]] = None,
139+
) -> PaginatedResult[MODEL]:
140+
"""Find models using filters and pagination
141+
142+
E.g.
143+
find(name="John") finds all models with name = John
144+
145+
:param per_page: Number of models to retrieve
146+
:type per_page: int
147+
:param page: Page to retrieve
148+
:type page: int
149+
:param search_params: A dictionary containing equality filters
150+
:param order_by:
151+
:return: A collection of models
152+
:rtype: List
153+
"""
154+
155+
find_stmt = self._find_query(search_params, order_by)
156+
paginated_stmt = self._paginate_query(find_stmt, page, per_page)
157+
158+
async with self._get_session() as session:
159+
total_items_count = (
160+
await session.execute(self._count_query(find_stmt))
161+
).scalar() or 0
162+
result_items = [
163+
x for x in (await session.execute(paginated_stmt)).scalars()
164+
]
165+
166+
return self._build_paginated_result(
167+
result_items=result_items,
168+
total_items_count=total_items_count,
169+
page=page,
170+
per_page=per_page,
171+
)

sqlalchemy_bind_manager/_repository/common.py

Lines changed: 110 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
from abc import ABC
22
from enum import Enum
33
from functools import partial
4-
from typing import TypeVar, Union, Generic, Type, Tuple, Iterable, Any, Mapping
5-
6-
from sqlalchemy import asc, desc
7-
from sqlalchemy.orm import object_mapper, class_mapper, Mapper
4+
from math import ceil
5+
from typing import (
6+
TypeVar,
7+
Union,
8+
Generic,
9+
Type,
10+
Tuple,
11+
Iterable,
12+
Any,
13+
Mapping,
14+
List,
15+
Collection,
16+
)
17+
18+
from pydantic.generics import GenericModel
19+
from sqlalchemy import asc, desc, select, func
20+
from sqlalchemy.orm import object_mapper, class_mapper, Mapper, lazyload
821
from sqlalchemy.orm.exc import UnmappedInstanceError
922
from sqlalchemy.sql import Select
1023

@@ -19,9 +32,30 @@ class SortDirection(Enum):
1932
DESC = partial(desc)
2033

2134

35+
class PaginatedResult(GenericModel, Generic[MODEL]):
36+
items: List[MODEL]
37+
page: int
38+
per_page: int
39+
total_pages: int
40+
total_items: int
41+
42+
2243
class BaseRepository(Generic[MODEL], ABC):
44+
_max_query_limit: int = 50
2345
_model: Type[MODEL]
2446

47+
def __init__(self, model_class: Union[Type[MODEL], None] = None) -> None:
48+
if getattr(self, "_model", None) is None and model_class is not None:
49+
self._model = model_class
50+
51+
if getattr(self, "_model", None) is None or not self._is_mapped_object(
52+
self._model()
53+
):
54+
raise InvalidModel(
55+
"You need to supply a valid model class either in the `model_class` parameter"
56+
" or in the `_model` class property."
57+
)
58+
2559
def _is_mapped_object(self, obj: object) -> bool:
2660
"""Checks if the object is handled by the repository and is mapped in SQLAlchemy.
2761
@@ -65,7 +99,6 @@ def _filter_select(self, stmt: Select, search_params: Mapping[str, Any]) -> Sele
6599
:param search_params: Any keyword argument to be used as equality filter
66100
:return: The filtered query
67101
"""
68-
# TODO: Add support for offset/limit
69102
# TODO: Add support for relationship eager load
70103
for k, v in search_params.items():
71104
"""
@@ -100,3 +133,75 @@ def _filter_order_by(
100133
stmt = stmt.order_by(value[1].value(getattr(self._model, value[0])))
101134

102135
return stmt
136+
137+
def _find_query(
138+
self,
139+
search_params: Union[None, Mapping[str, Any]] = None,
140+
order_by: Union[None, Iterable[Union[str, Tuple[str, SortDirection]]]] = None,
141+
) -> Select:
142+
stmt = select(self._model)
143+
144+
if search_params:
145+
stmt = self._filter_select(stmt, search_params)
146+
if order_by is not None:
147+
stmt = self._filter_order_by(stmt, order_by)
148+
149+
return stmt
150+
151+
def _count_query(
152+
self,
153+
query: Select,
154+
) -> Select:
155+
return select(func.count()).select_from(
156+
query.options(lazyload("*")).order_by(None).subquery() # type: ignore
157+
)
158+
159+
def _paginate_query(
160+
self,
161+
stmt: Select,
162+
page: int,
163+
per_page: int,
164+
) -> Select:
165+
"""Build the query offset and limit clauses from submitted parameters.
166+
167+
:param stmt: a Select statement
168+
:type stmt: Select
169+
:param page: Number of models to skip
170+
:type page: int
171+
:param per_page: Number of models to retrieve
172+
:type per_page: int
173+
:return: The filtered query
174+
"""
175+
176+
_offset = max((page - 1) * per_page, 0)
177+
if _offset > 0:
178+
stmt = stmt.offset(_offset)
179+
180+
_limit = max(min(per_page, self._max_query_limit), 0)
181+
stmt = stmt.limit(_limit)
182+
183+
return stmt
184+
185+
def _build_paginated_result(
186+
self,
187+
result_items: Collection[MODEL],
188+
total_items_count: int,
189+
page: int,
190+
per_page: int,
191+
) -> PaginatedResult:
192+
193+
_per_page = max(min(per_page, self._max_query_limit), 0)
194+
total_pages = (
195+
0
196+
if total_items_count == 0 or total_items_count is None
197+
else ceil(total_items_count / _per_page)
198+
)
199+
_page = 0 if len(result_items) == 0 else min(page, total_pages)
200+
201+
return PaginatedResult(
202+
items=result_items,
203+
page=_page,
204+
per_page=_per_page,
205+
total_items=total_items_count,
206+
total_pages=total_pages,
207+
)

sqlalchemy_bind_manager/_repository/sync.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
from abc import ABC
22
from contextlib import contextmanager
3-
from typing import Union, Generic, Iterable, Tuple, List, Iterator, Any, Mapping
3+
from math import ceil
4+
from typing import (
5+
Union,
6+
Generic,
7+
Iterable,
8+
Tuple,
9+
List,
10+
Iterator,
11+
Any,
12+
Mapping,
13+
Type,
14+
)
415

5-
from sqlalchemy import select
616
from sqlalchemy.orm import Session
717

818
from .._bind_manager import SQLAlchemyBind
919
from .._transaction_handler import SessionHandler
1020
from ..exceptions import ModelNotFound, InvalidConfig
11-
from .common import MODEL, PRIMARY_KEY, SortDirection, BaseRepository
21+
from .common import MODEL, PRIMARY_KEY, SortDirection, BaseRepository, PaginatedResult
1222

1323

1424
class SQLAlchemyRepository(Generic[MODEL], BaseRepository[MODEL], ABC):
@@ -19,14 +29,17 @@ def __init__(
1929
self,
2030
bind: Union[SQLAlchemyBind, None] = None,
2131
session: Union[Session, None] = None,
32+
model_class: Union[Type[MODEL], None] = None,
2233
) -> None:
2334
"""
2435
:param bind: A configured instance of SQLAlchemyBind
25-
:type bind: SQLAlchemyBind
36+
:type bind: Union[SQLAlchemyBind, None]
2637
:param session: An externally managed session
27-
:type session: Session
38+
:type session: Union[Session, None]
39+
:param model_class: A mapped SQLAlchemy model
40+
:type model_class: Union[Type[MODEL], None]
2841
"""
29-
super().__init__()
42+
super().__init__(model_class=model_class)
3043
if not (bool(bind) ^ bool(session)):
3144
raise InvalidConfig("Either `bind` or `session` have to be used, not both")
3245
self._external_session = session
@@ -97,18 +110,51 @@ def find(
97110
E.g.
98111
find(name="John") finds all models with name = John
99112
100-
:param order_by:
101113
:param search_params: A dictionary containing equality filters
114+
:param order_by:
102115
:return: A collection of models
103116
:rtype: List
104117
"""
105-
stmt = select(self._model)
106-
if search_params:
107-
stmt = self._filter_select(stmt, search_params)
108-
109-
if order_by is not None:
110-
stmt = self._filter_order_by(stmt, order_by)
118+
stmt = self._find_query(search_params, order_by)
111119

112120
with self._get_session() as session:
113121
result = session.execute(stmt)
114122
return [x for x in result.scalars()]
123+
124+
def paginated_find(
125+
self,
126+
per_page: int,
127+
page: int,
128+
search_params: Union[None, Mapping[str, Any]] = None,
129+
order_by: Union[None, Iterable[Union[str, Tuple[str, SortDirection]]]] = None,
130+
) -> PaginatedResult[MODEL]:
131+
"""Find models using filters and pagination
132+
133+
E.g.
134+
find(name="John") finds all models with name = John
135+
136+
:param per_page: Number of models to retrieve
137+
:type per_page: int
138+
:param page: Page to retrieve
139+
:type page: int
140+
:param search_params: A dictionary containing equality filters
141+
:param order_by:
142+
:return: A collection of models
143+
:rtype: List
144+
"""
145+
146+
find_stmt = self._find_query(search_params, order_by)
147+
paginated_stmt = self._paginate_query(find_stmt, page, per_page)
148+
149+
with self._get_session() as session:
150+
total_items_count = (
151+
session.execute(self._count_query(find_stmt)).scalar() or 0
152+
)
153+
result_items = [x for x in session.execute(paginated_stmt).scalars()]
154+
155+
return self._build_paginated_result(
156+
result_items=result_items,
157+
total_items_count=total_items_count,
158+
page=page,
159+
per_page=per_page,
160+
)

0 commit comments

Comments
 (0)