Skip to content

Commit eafa015

Browse files
committed
Remove Cursor encoding/decoding logic
1 parent 5b64aac commit eafa015

File tree

9 files changed

+95
-146
lines changed

9 files changed

+95
-146
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .async_ import SQLAlchemyAsyncRepository
22
from .base_repository import SortDirection
3-
from .common import Cursor, CursorPaginatedResult, PaginatedResult
3+
from .common import CursorPaginatedResult, CursorReference, PaginatedResult
44
from .sync import SQLAlchemyRepository

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
BaseRepository,
2222
SortDirection,
2323
)
24-
from .common import MODEL, PRIMARY_KEY, Cursor, CursorPaginatedResult, PaginatedResult
24+
from .common import (
25+
MODEL,
26+
PRIMARY_KEY,
27+
CursorPaginatedResult,
28+
CursorReference,
29+
PaginatedResult,
30+
)
2531

2632

2733
class SQLAlchemyAsyncRepository(Generic[MODEL], BaseRepository[MODEL], ABC):
@@ -180,7 +186,7 @@ async def paginated_find(
180186
async def cursor_paginated_find(
181187
self,
182188
items_per_page: int,
183-
reference_cursor: Union[Cursor, str, None] = None,
189+
cursor_reference: Union[CursorReference, None] = None,
184190
is_end_cursor: bool = False,
185191
search_params: Union[None, Mapping[str, Any]] = None,
186192
) -> CursorPaginatedResult[MODEL]:
@@ -201,13 +207,10 @@ async def cursor_paginated_find(
201207
:return: A collection of models
202208
:rtype: List
203209
"""
204-
if isinstance(reference_cursor, str):
205-
reference_cursor = self.decode_cursor(reference_cursor)
206-
207210
find_stmt = self._find_query(search_params)
208211
paginated_stmt = self._cursor_paginated_query(
209212
find_stmt,
210-
reference_cursor=reference_cursor,
213+
cursor_reference=cursor_reference,
211214
is_end_cursor=is_end_cursor,
212215
per_page=items_per_page,
213216
)
@@ -224,6 +227,6 @@ async def cursor_paginated_find(
224227
result_items=result_items,
225228
total_items_count=total_items_count,
226229
items_per_page=items_per_page,
227-
reference_cursor=reference_cursor,
230+
cursor_reference=cursor_reference,
228231
is_end_cursor=is_end_cursor,
229232
)

sqlalchemy_bind_manager/_repository/base_repository.py

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from abc import ABC
2-
from base64 import b64decode, b64encode
32
from enum import Enum
43
from functools import partial
54
from math import ceil
@@ -24,9 +23,9 @@
2423

2524
from .common import (
2625
MODEL,
27-
Cursor,
2826
CursorPageInfo,
2927
CursorPaginatedResult,
28+
CursorReference,
3029
PageInfo,
3130
PaginatedResult,
3231
)
@@ -181,7 +180,7 @@ def _paginate_query_by_page(
181180
def _cursor_paginated_query(
182181
self,
183182
stmt: Select,
184-
reference_cursor: Union[Cursor, None],
183+
cursor_reference: Union[CursorReference, None],
185184
is_end_cursor: bool = False,
186185
per_page: int = _max_query_limit,
187186
) -> Select:
@@ -200,52 +199,52 @@ def _cursor_paginated_query(
200199

201200
forward_limit = self._sanitised_query_limit(per_page) + 1
202201

203-
if not reference_cursor:
202+
if not cursor_reference:
204203
return stmt.limit(forward_limit).order_by( # type: ignore
205204
asc(self._model_pk())
206205
)
207206

208207
# TODO: Use window functions
209208
if not is_end_cursor:
210209
previous_query = stmt.where(
211-
getattr(self._model, reference_cursor.column) <= reference_cursor.value
210+
getattr(self._model, cursor_reference.column) <= cursor_reference.value
212211
)
213212
previous_query = (
214213
self._filter_order_by(
215-
previous_query, [(reference_cursor.column, SortDirection.DESC)]
214+
previous_query, [(cursor_reference.column, SortDirection.DESC)]
216215
)
217216
.limit(1)
218217
.subquery("previous") # type: ignore
219218
)
220219

221220
page_query = stmt.where(
222-
getattr(self._model, reference_cursor.column) > reference_cursor.value
221+
getattr(self._model, cursor_reference.column) > cursor_reference.value
223222
)
224223
page_query = (
225224
self._filter_order_by(
226-
page_query, [(reference_cursor.column, SortDirection.ASC)]
225+
page_query, [(cursor_reference.column, SortDirection.ASC)]
227226
)
228227
.limit(forward_limit)
229228
.subquery("page") # type: ignore
230229
)
231230
else:
232231
previous_query = stmt.where(
233-
getattr(self._model, reference_cursor.column) >= reference_cursor.value
232+
getattr(self._model, cursor_reference.column) >= cursor_reference.value
234233
)
235234
previous_query = (
236235
self._filter_order_by(
237-
previous_query, [(reference_cursor.column, SortDirection.ASC)]
236+
previous_query, [(cursor_reference.column, SortDirection.ASC)]
238237
)
239238
.limit(1)
240239
.subquery("previous") # type: ignore
241240
)
242241

243242
page_query = stmt.where(
244-
getattr(self._model, reference_cursor.column) < reference_cursor.value
243+
getattr(self._model, cursor_reference.column) < cursor_reference.value
245244
)
246245
page_query = (
247246
self._filter_order_by(
248-
page_query, [(reference_cursor.column, SortDirection.DESC)]
247+
page_query, [(cursor_reference.column, SortDirection.DESC)]
249248
)
250249
.limit(forward_limit)
251250
.subquery("page") # type: ignore
@@ -256,7 +255,7 @@ def _cursor_paginated_query(
256255
self._model,
257256
select(previous_query)
258257
.union(select(page_query))
259-
.order_by(reference_cursor.column)
258+
.order_by(cursor_reference.column)
260259
.subquery(), # type: ignore
261260
)
262261
)
@@ -302,7 +301,7 @@ def _build_cursor_paginated_result(
302301
result_items: List[MODEL],
303302
total_items_count: int,
304303
items_per_page: int,
305-
reference_cursor: Union[Cursor, None],
304+
cursor_reference: Union[CursorReference, None],
306305
is_end_cursor: bool,
307306
) -> CursorPaginatedResult:
308307
"""
@@ -312,7 +311,7 @@ def _build_cursor_paginated_result(
312311
:param result_items:
313312
:param total_items_count:
314313
:param items_per_page:
315-
:param reference_cursor:
314+
:param cursor_reference:
316315
:param is_end_cursor:
317316
:return:
318317
"""
@@ -329,7 +328,7 @@ def _build_cursor_paginated_result(
329328
if not result_items:
330329
return result_structure
331330

332-
if not reference_cursor:
331+
if not cursor_reference:
333332
has_previous_page = False
334333
has_next_page = len(result_items) > sanitised_query_limit
335334
if has_next_page:
@@ -341,7 +340,7 @@ def _build_cursor_paginated_result(
341340

342341
elif is_end_cursor:
343342
index = -1
344-
reference_column = reference_cursor.column
343+
reference_column = cursor_reference.column
345344
last_found_cursor_value = getattr(result_items[index], reference_column)
346345
"""
347346
Currently we support only numeric or string model values for cursors,
@@ -352,17 +351,17 @@ def _build_cursor_paginated_result(
352351
9 < 10 but '9' > '10'
353352
"""
354353
if isinstance(last_found_cursor_value, str):
355-
has_next_page = last_found_cursor_value >= reference_cursor.value
354+
has_next_page = last_found_cursor_value >= cursor_reference.value
356355
else:
357-
has_next_page = last_found_cursor_value >= float(reference_cursor.value)
356+
has_next_page = last_found_cursor_value >= float(cursor_reference.value)
358357
if has_next_page:
359358
result_items.pop(index)
360359
has_previous_page = len(result_items) > sanitised_query_limit
361360
if has_previous_page:
362361
result_items = result_items[-sanitised_query_limit:]
363362
else:
364363
index = 0
365-
reference_column = reference_cursor.column
364+
reference_column = cursor_reference.column
366365
first_found_cursor_value = getattr(result_items[index], reference_column)
367366
"""
368367
Currently we support only numeric or string model values for cursors,
@@ -373,10 +372,10 @@ def _build_cursor_paginated_result(
373372
9 < 10 but '9' > '10'
374373
"""
375374
if isinstance(first_found_cursor_value, str):
376-
has_previous_page = first_found_cursor_value <= reference_cursor.value
375+
has_previous_page = first_found_cursor_value <= cursor_reference.value
377376
else:
378377
has_previous_page = first_found_cursor_value <= float(
379-
reference_cursor.value
378+
cursor_reference.value
380379
)
381380
if has_previous_page:
382381
result_items.pop(index)
@@ -389,28 +388,17 @@ def _build_cursor_paginated_result(
389388
result_structure.page_info.has_previous_page = has_previous_page
390389

391390
if result_items:
392-
result_structure.page_info.start_cursor = self.encode_cursor(
393-
Cursor(
394-
column=reference_column,
395-
value=getattr(result_items[0], reference_column),
396-
)
391+
result_structure.page_info.start_cursor = CursorReference(
392+
column=reference_column,
393+
value=getattr(result_items[0], reference_column),
397394
)
398-
result_structure.page_info.end_cursor = self.encode_cursor(
399-
Cursor(
400-
column=reference_column,
401-
value=getattr(result_items[-1], reference_column),
402-
)
395+
result_structure.page_info.end_cursor = CursorReference(
396+
column=reference_column,
397+
value=getattr(result_items[-1], reference_column),
403398
)
404399

405400
return result_structure
406401

407-
def encode_cursor(self, cursor: Cursor) -> str:
408-
serialised_cursor = cursor.json()
409-
return b64encode(serialised_cursor.encode()).decode()
410-
411-
def decode_cursor(self, cursor: str) -> Cursor:
412-
return Cursor.parse_raw(b64decode(cursor))
413-
414402
def _model_pk(self) -> str:
415403
primary_keys = inspect(self._model).primary_key # type: ignore
416404
if len(primary_keys) > 1:

sqlalchemy_bind_manager/_repository/common.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,20 @@ class PaginatedResult(GenericModel, Generic[MODEL]):
2121
page_info: PageInfo
2222

2323

24+
class CursorReference(BaseModel):
25+
column: str
26+
value: str
27+
28+
2429
class CursorPageInfo(BaseModel):
2530
items_per_page: int
2631
total_items: int
2732
has_next_page: bool = False
2833
has_previous_page: bool = False
29-
start_cursor: Union[str, None] = None
30-
end_cursor: Union[str, None] = None
34+
start_cursor: Union[CursorReference, None] = None
35+
end_cursor: Union[CursorReference, None] = None
3136

3237

3338
class CursorPaginatedResult(GenericModel, Generic[MODEL]):
3439
items: List[MODEL]
3540
page_info: CursorPageInfo
36-
37-
38-
class Cursor(BaseModel):
39-
column: str
40-
value: str

sqlalchemy_bind_manager/_repository/sync.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
BaseRepository,
2222
SortDirection,
2323
)
24-
from .common import MODEL, PRIMARY_KEY, Cursor, CursorPaginatedResult, PaginatedResult
24+
from .common import (
25+
MODEL,
26+
PRIMARY_KEY,
27+
CursorPaginatedResult,
28+
CursorReference,
29+
PaginatedResult,
30+
)
2531

2632

2733
class SQLAlchemyRepository(Generic[MODEL], BaseRepository[MODEL], ABC):
@@ -168,7 +174,7 @@ def paginated_find(
168174
def cursor_paginated_find(
169175
self,
170176
items_per_page: int,
171-
reference_cursor: Union[Cursor, str, None] = None,
177+
cursor_reference: Union[CursorReference, None] = None,
172178
is_end_cursor: bool = False,
173179
search_params: Union[None, Mapping[str, Any]] = None,
174180
) -> CursorPaginatedResult[MODEL]:
@@ -189,14 +195,11 @@ def cursor_paginated_find(
189195
:return: A collection of models
190196
:rtype: List
191197
"""
192-
if isinstance(reference_cursor, str):
193-
reference_cursor = self.decode_cursor(reference_cursor)
194-
195198
find_stmt = self._find_query(search_params)
196199

197200
paginated_stmt = self._cursor_paginated_query(
198201
find_stmt,
199-
reference_cursor=reference_cursor,
202+
cursor_reference=cursor_reference,
200203
is_end_cursor=is_end_cursor,
201204
per_page=items_per_page,
202205
)
@@ -211,6 +214,6 @@ def cursor_paginated_find(
211214
result_items=result_items,
212215
total_items_count=total_items_count,
213216
items_per_page=items_per_page,
214-
reference_cursor=reference_cursor,
217+
cursor_reference=cursor_reference,
215218
is_end_cursor=is_end_cursor,
216219
)

sqlalchemy_bind_manager/protocols.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
PRIMARY_KEY,
1515
)
1616
from sqlalchemy_bind_manager.repository import (
17-
Cursor,
1817
CursorPaginatedResult,
18+
CursorReference,
1919
PaginatedResult,
2020
SortDirection,
2121
)
@@ -54,7 +54,7 @@ async def paginated_find(
5454
async def cursor_paginated_find(
5555
self,
5656
items_per_page: int,
57-
reference_cursor: Union[Cursor, str, None] = None,
57+
cursor_reference: Union[CursorReference, None] = None,
5858
is_end_cursor: bool = False,
5959
search_params: Union[None, Mapping[str, Any]] = None,
6060
) -> CursorPaginatedResult[MODEL]:
@@ -94,7 +94,7 @@ def paginated_find(
9494
def cursor_paginated_find(
9595
self,
9696
items_per_page: int,
97-
reference_cursor: Union[Cursor, str, None] = None,
97+
cursor_reference: Union[CursorReference, None] = None,
9898
is_end_cursor: bool = False,
9999
search_params: Union[None, Mapping[str, Any]] = None,
100100
) -> CursorPaginatedResult[MODEL]:

sqlalchemy_bind_manager/repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._repository import (
2-
Cursor,
32
CursorPaginatedResult,
3+
CursorReference,
44
PaginatedResult,
55
SortDirection,
66
SQLAlchemyAsyncRepository,

0 commit comments

Comments
 (0)