11from abc import ABC
22from enum import Enum
33from functools import partial
4- from typing import TypeVar , Union , Generic , Type , Tuple , Iterable , Any , Mapping
5-
6- from sqlalchemy import asc , desc
7- from sqlalchemy .orm import object_mapper , class_mapper , Mapper
4+ from math import ceil
5+ from typing import (
6+ TypeVar ,
7+ Union ,
8+ Generic ,
9+ Type ,
10+ Tuple ,
11+ Iterable ,
12+ Any ,
13+ Mapping ,
14+ List ,
15+ Collection ,
16+ )
17+
18+ from pydantic .generics import GenericModel
19+ from sqlalchemy import asc , desc , select , func
20+ from sqlalchemy .orm import object_mapper , class_mapper , Mapper , lazyload
821from sqlalchemy .orm .exc import UnmappedInstanceError
922from sqlalchemy .sql import Select
1023
@@ -19,9 +32,30 @@ class SortDirection(Enum):
1932 DESC = partial (desc )
2033
2134
35+ class PaginatedResult (GenericModel , Generic [MODEL ]):
36+ items : List [MODEL ]
37+ page : int
38+ per_page : int
39+ total_pages : int
40+ total_items : int
41+
42+
2243class BaseRepository (Generic [MODEL ], ABC ):
44+ _max_query_limit : int = 50
2345 _model : Type [MODEL ]
2446
47+ def __init__ (self , model_class : Union [Type [MODEL ], None ] = None ) -> None :
48+ if getattr (self , "_model" , None ) is None and model_class is not None :
49+ self ._model = model_class
50+
51+ if getattr (self , "_model" , None ) is None or not self ._is_mapped_object (
52+ self ._model ()
53+ ):
54+ raise InvalidModel (
55+ "You need to supply a valid model class either in the `model_class` parameter"
56+ " or in the `_model` class property."
57+ )
58+
2559 def _is_mapped_object (self , obj : object ) -> bool :
2660 """Checks if the object is handled by the repository and is mapped in SQLAlchemy.
2761
@@ -65,7 +99,6 @@ def _filter_select(self, stmt: Select, search_params: Mapping[str, Any]) -> Sele
6599 :param search_params: Any keyword argument to be used as equality filter
66100 :return: The filtered query
67101 """
68- # TODO: Add support for offset/limit
69102 # TODO: Add support for relationship eager load
70103 for k , v in search_params .items ():
71104 """
@@ -100,3 +133,75 @@ def _filter_order_by(
100133 stmt = stmt .order_by (value [1 ].value (getattr (self ._model , value [0 ])))
101134
102135 return stmt
136+
137+ def _find_query (
138+ self ,
139+ search_params : Union [None , Mapping [str , Any ]] = None ,
140+ order_by : Union [None , Iterable [Union [str , Tuple [str , SortDirection ]]]] = None ,
141+ ) -> Select :
142+ stmt = select (self ._model )
143+
144+ if search_params :
145+ stmt = self ._filter_select (stmt , search_params )
146+ if order_by is not None :
147+ stmt = self ._filter_order_by (stmt , order_by )
148+
149+ return stmt
150+
151+ def _count_query (
152+ self ,
153+ query : Select ,
154+ ) -> Select :
155+ return select (func .count ()).select_from (
156+ query .options (lazyload ("*" )).order_by (None ).subquery () # type: ignore
157+ )
158+
159+ def _paginate_query (
160+ self ,
161+ stmt : Select ,
162+ page : int ,
163+ per_page : int ,
164+ ) -> Select :
165+ """Build the query offset and limit clauses from submitted parameters.
166+
167+ :param stmt: a Select statement
168+ :type stmt: Select
169+ :param page: Number of models to skip
170+ :type page: int
171+ :param per_page: Number of models to retrieve
172+ :type per_page: int
173+ :return: The filtered query
174+ """
175+
176+ _offset = max ((page - 1 ) * per_page , 0 )
177+ if _offset > 0 :
178+ stmt = stmt .offset (_offset )
179+
180+ _limit = max (min (per_page , self ._max_query_limit ), 0 )
181+ stmt = stmt .limit (_limit )
182+
183+ return stmt
184+
185+ def _build_paginated_result (
186+ self ,
187+ result_items : Collection [MODEL ],
188+ total_items_count : int ,
189+ page : int ,
190+ per_page : int ,
191+ ) -> PaginatedResult :
192+
193+ _per_page = max (min (per_page , self ._max_query_limit ), 0 )
194+ total_pages = (
195+ 0
196+ if total_items_count == 0 or total_items_count is None
197+ else ceil (total_items_count / _per_page )
198+ )
199+ _page = 0 if len (result_items ) == 0 else min (page , total_pages )
200+
201+ return PaginatedResult (
202+ items = result_items ,
203+ page = _page ,
204+ per_page = _per_page ,
205+ total_items = total_items_count ,
206+ total_pages = total_pages ,
207+ )
0 commit comments