Skip to content

Commit 4dcb05e

Browse files
committed
Make sure CursorReference values are not coerced and remove relevant logic
1 parent d83ef25 commit 4dcb05e

File tree

7 files changed

+64
-32
lines changed

7 files changed

+64
-32
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ testpaths = [
6666

6767
[tool.mypy]
6868
files = "sqlalchemy_bind_manager"
69+
plugins = "pydantic.mypy"
6970

7071
[tool.ruff]
7172
select = ["E", "F", "I"]

sqlalchemy_bind_manager/_repository/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Generic, List, TypeVar, Union
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, StrictInt, StrictStr
44
from pydantic.generics import GenericModel
55

66
MODEL = TypeVar("MODEL")
@@ -23,7 +23,7 @@ class PaginatedResult(GenericModel, Generic[MODEL]):
2323

2424
class CursorReference(BaseModel):
2525
column: str
26-
value: str
26+
value: Union[StrictStr, StrictInt]
2727

2828

2929
class CursorPageInfo(BaseModel):

sqlalchemy_bind_manager/_repository/result_presenters.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from math import ceil
2-
from typing import Collection, List, Union
2+
from typing import List, Union
33

44
from sqlalchemy import inspect
55

@@ -56,18 +56,11 @@ def build_result(
5656
index = -1
5757
reference_column = cursor_reference.column
5858
last_found_cursor_value = getattr(result_items[index], reference_column)
59-
"""
60-
Currently we support only numeric or string model values for cursors,
61-
but pydantic models (cursor) coerce always the value as string.
62-
This mean if the value is not actually string we need to cast to
63-
ensure correct ordering is evaluated.
64-
e.g.
65-
9 < 10 but '9' > '10'
66-
"""
67-
if isinstance(last_found_cursor_value, str):
68-
has_next_page = last_found_cursor_value >= cursor_reference.value
69-
else:
70-
has_next_page = last_found_cursor_value >= float(cursor_reference.value)
59+
if not isinstance(last_found_cursor_value, type(cursor_reference.value)):
60+
raise TypeError(
61+
"Values from CursorReference and results must be of the same type"
62+
)
63+
has_next_page = last_found_cursor_value >= cursor_reference.value
7164
if has_next_page:
7265
result_items.pop(index)
7366
has_previous_page = len(result_items) > items_per_page
@@ -77,20 +70,11 @@ def build_result(
7770
index = 0
7871
reference_column = cursor_reference.column
7972
first_found_cursor_value = getattr(result_items[index], reference_column)
80-
"""
81-
Currently we support only numeric or string model values for cursors,
82-
but pydantic models (cursor) coerce always the value as string.
83-
This mean if the value is not actually string we need to cast to
84-
ensure correct ordering is evaluated.
85-
e.g.
86-
9 < 10 but '9' > '10'
87-
"""
88-
if isinstance(first_found_cursor_value, str):
89-
has_previous_page = first_found_cursor_value <= cursor_reference.value
90-
else:
91-
has_previous_page = first_found_cursor_value <= float(
92-
cursor_reference.value
73+
if not isinstance(first_found_cursor_value, type(cursor_reference.value)):
74+
raise TypeError(
75+
"Values from CursorReference and results must be of the same type"
9376
)
77+
has_previous_page = first_found_cursor_value <= cursor_reference.value
9478
if has_previous_page:
9579
result_items.pop(index)
9680
has_next_page = len(result_items) > items_per_page
@@ -117,21 +101,20 @@ def build_result(
117101
class PaginatedResultPresenter:
118102
@staticmethod
119103
def build_result(
120-
result_items: Collection[MODEL],
104+
result_items: List[MODEL],
121105
total_items_count: int,
122106
page: int,
123107
items_per_page: int,
124108
) -> PaginatedResult:
125-
126109
total_pages = (
127110
0
128111
if total_items_count == 0 or total_items_count is None
129112
else ceil(total_items_count / items_per_page)
130113
)
131114

132115
_page = 0 if len(result_items) == 0 else min(page, total_pages)
133-
has_next_page = _page and _page < total_pages
134-
has_previous_page = _page and _page > 1
116+
has_next_page = bool(_page and _page < total_pages)
117+
has_previous_page = bool(_page and _page > 1)
135118

136119
return PaginatedResult(
137120
items=result_items,

tests/repository/result_presenters/__init__.py

Whitespace-only changes.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from unittest.mock import Mock, patch
2+
3+
import pytest
4+
5+
from sqlalchemy_bind_manager._repository.result_presenters import _pk_from_result_object
6+
7+
8+
def test_exception_raised_if_multiple_primary_keys():
9+
with patch(
10+
"sqlalchemy_bind_manager._repository.result_presenters.inspect",
11+
return_value=Mock(primary_key=["1", "2"]),
12+
), pytest.raises(NotImplementedError):
13+
_pk_from_result_object("irrelevant")
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from dataclasses import dataclass
2+
3+
import pytest
4+
5+
from sqlalchemy_bind_manager._repository import CursorReference
6+
from sqlalchemy_bind_manager._repository.result_presenters import (
7+
CursorPaginatedResultPresenter,
8+
)
9+
10+
11+
@dataclass
12+
class MyModel:
13+
model_id: int
14+
name: str
15+
16+
17+
@pytest.mark.parametrize(["is_end_cursor"], [(True,), (False,)])
18+
def test_fails_if_reference_cursor_wrong_type(is_end_cursor):
19+
with pytest.raises(TypeError):
20+
CursorPaginatedResultPresenter.build_result(
21+
result_items=[MyModel(model_id=1, name="test")],
22+
total_items_count=10,
23+
items_per_page=1,
24+
cursor_reference=CursorReference(column="model_id", value="1"),
25+
is_end_cursor=is_end_cursor,
26+
)

tests/repository/test_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from sqlalchemy_bind_manager._repository import CursorReference
2+
3+
4+
def test_cursor_reference_doesnt_coerce_values():
5+
r = CursorReference(
6+
column="column_name",
7+
value=10,
8+
)
9+
assert isinstance(r.value, int)

0 commit comments

Comments
 (0)