Skip to content

Commit fb68cdf

Browse files
authored
Merge pull request #10 from RSS-Engineering/support-pagination-dynamodb
Support pagination in DynamoDB
2 parents 0635416 + 8200d3d commit fb68cdf

File tree

4 files changed

+132
-18
lines changed

4 files changed

+132
-18
lines changed

README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,15 @@ NOTE: Rule complexity is limited by the querying capabilities of the backend.
5353

5454
### DynamoDB
5555

56-
`query(query_expr: Optional[Rule], filter_expr: Optional[Rule])` - Providing a
57-
`query_expr` parameter will try to apply the keys of the expression to an
58-
existing index. Providing a `filter_expr` parameter will filter the results of
56+
`query(query_expr: Optional[Rule], filter_expr: Optional[Rule], limit: Optional[str], exclusive_start_key: Optional[tuple[Any]], order: str = 'asc'`
57+
- Providing a `query_expr` parameter will try to apply the keys of the expression to an
58+
existing index.
59+
- Providing a `filter_expr` parameter will filter the results of
5960
a passed `query_expr` or run a dynamodb `scan` if no `query_expr` is passed.
60-
An empty call to `query()` will return the scan results (and be resource
61+
- An empty call to `query()` will return the scan results (and be resource
6162
intensive).
63+
- Providing a `limit` parameter will limit the number of results. If more results remain, the returned dataset will have an `last_evaluated_key` property that can be passed to `exclusive_start_key` to continue with the next page.
64+
- Providing `order='desc'` will return the result set in descending order. This is not available for query calls that "scan" dynamodb.
6265

6366
## Backend Configuration Members
6467

pydanticrud/backends/dynamodb.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Set
1+
from typing import Optional, Set, Any, Dict
22
import logging
33
import json
44
from datetime import datetime
@@ -9,6 +9,7 @@
99
from botocore.exceptions import ClientError
1010
from rule_engine import Rule, ast, types
1111

12+
from ..main import IterableResult
1213
from ..exceptions import DoesNotExist, ConditionCheckFailed
1314

1415
log = logging.getLogger(__name__)
@@ -116,9 +117,25 @@ def index_definition(index_name, keys, gsi=False):
116117
return schema
117118

118119

120+
class DynamoIterableResult(IterableResult):
121+
def __init__(self, cls, result, serialized_items):
122+
super(DynamoIterableResult, self).__init__(cls, serialized_items, result.get("Count"))
123+
124+
self.last_evaluated_key = None
125+
lsk = result.get("LastEvaluatedKey")
126+
if lsk:
127+
_key = [lsk[cls.Config.hash_key]]
128+
if cls.Config.range_key:
129+
_key.append(lsk[cls.Config.range_key])
130+
self.last_evaluated_key = tuple(_key)
131+
132+
self.scanned_count = result["ScannedCount"]
133+
134+
119135
class Backend:
120136
def __init__(self, cls):
121137
cfg = cls.Config
138+
self.cls = cls
122139
self.schema = cls.schema()
123140
self.hash_key = cfg.hash_key
124141
self.range_key = getattr(cfg, 'range_key', None)
@@ -274,11 +291,22 @@ def exists(self):
274291
except ClientError:
275292
return False
276293

277-
def query(self, query_expr: Optional[Rule] = None, filter_expr: Optional[Rule] = None):
294+
def query(self,
295+
query_expr: Optional[Rule] = None,
296+
filter_expr: Optional[Rule] = None,
297+
limit: Optional[int] = None,
298+
exclusive_start_key: Optional[tuple[Any]] = None,
299+
order: str = 'asc',
300+
):
278301
table = self.get_table()
279302
f_expr, _ = rule_to_boto_expression(filter_expr) if filter_expr else (None, set())
280303

281304
params = {}
305+
306+
if limit:
307+
params["Limit"] = limit
308+
if exclusive_start_key:
309+
params["ExclusiveStartKey"] = self._key_param_to_dict(exclusive_start_key)
282310
if f_expr:
283311
params["FilterExpression"] = f_expr
284312

@@ -291,6 +319,9 @@ def query(self, query_expr: Optional[Rule] = None, filter_expr: Optional[Rule] =
291319
index_name = self._get_best_index(keys_used)
292320
params["KeyConditionExpression"] = q_expr
293321

322+
if order != 'asc':
323+
params["ScanIndexForward"] = False
324+
294325
if index_name:
295326
params["IndexName"] = index_name
296327
elif not keys_used.issubset({self.hash_key, self.range_key}):
@@ -305,16 +336,18 @@ def query(self, query_expr: Optional[Rule] = None, filter_expr: Optional[Rule] =
305336
except DynamoDBNeedsKeyConditionError:
306337
raise ConditionCheckFailed("Non-key attributes are not valid in the query expression. Use filter "
307338
"expression")
308-
309339
else:
340+
if order != 'asc':
341+
raise ConditionCheckFailed("Scans do not support reverse order.")
342+
310343
try:
311344
resp = table.scan(**params)
312345
except ClientError as e:
313346
if e.response["Error"]["Code"] == "ResourceNotFoundException":
314347
return []
315348
raise e
316349

317-
return [self._deserialize_record(rec) for rec in resp["Items"]]
350+
return DynamoIterableResult(self.cls, resp, (self._deserialize_record(rec) for rec in resp["Items"]))
318351

319352
def get(self, key):
320353
_key = self._key_param_to_dict(key)

pydanticrud/main.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,26 @@ def __new__(mcs, name, bases, namespace, **kwargs):
1212
return cls
1313

1414

15+
class IterableResult:
16+
def __init__(self, cls, records, count=None):
17+
self.records = [cls.parse_obj(i) for i in records]
18+
self.count = count # None indicates "unknown"
19+
20+
self._current_index = 0
21+
22+
def __iter__(self):
23+
return self
24+
25+
def __next__(self):
26+
try:
27+
member = self.records[self._current_index]
28+
self._current_index += 1
29+
return member
30+
except IndexError:
31+
self._current_index = 0
32+
raise StopIteration
33+
34+
1535
class BaseModel(PydanticBaseModel, metaclass=CrudMetaClass):
1636
@classmethod
1737
def initialize(cls):
@@ -28,7 +48,9 @@ def exists(cls) -> bool:
2848
@classmethod
2949
def query(cls, *args, **kwargs):
3050
res = cls.__backend__.query(*args, **kwargs)
31-
return [cls.parse_obj(i) for i in res]
51+
if not isinstance(res, IterableResult):
52+
res = IterableResult(cls, res)
53+
return res
3254

3355
@classmethod
3456
def get(cls, *args, **kwargs):

tests/test_dynamodb.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,14 @@ def simple_query_data(simple_table):
211211

212212
@pytest.fixture(scope="module")
213213
def complex_query_data(complex_table):
214-
presets = [dict()] * 20
215-
data = [datum for datum in [complex_model_data_generator(**i) for i in presets]]
214+
record_count = 500
215+
presets = [dict()] * record_count
216+
accounts = [str(uuid4()) for i in range(4)]
217+
218+
data = [
219+
complex_model_data_generator(account=accounts[i % 4], **p)
220+
for i, p in enumerate(presets)
221+
]
216222
for datum in data:
217223
ComplexKeyModel.parse_obj(datum).save()
218224
try:
@@ -234,6 +240,7 @@ def alias_query_data(alias_table):
234240
for datum in data:
235241
AliasKeyModel.delete(datum["name"])
236242

243+
237244
@pytest.fixture(scope="module")
238245
def nested_query_data(nested_table):
239246
presets = [dict()] * 5
@@ -249,7 +256,7 @@ def nested_query_data(nested_table):
249256

250257

251258
@pytest.fixture(scope="module")
252-
def nested_query_data_optional(nested_table):
259+
def nested_query_data_empty_ticket(nested_table):
253260
presets = [dict()] * 5
254261
data = [datum for datum in [nested_model_data_generator(include_ticket=False, **i) for i in presets]]
255262
for datum in data:
@@ -282,6 +289,14 @@ def test_query_with_hash_key_simple(dynamo, simple_query_data):
282289
assert res_data == {simple_query_data[0]["name"]: simple_query_data[0]}
283290

284291

292+
def test_scan_errors_with_order(dynamo, simple_query_data):
293+
data_by_timestamp = simple_query_data[:]
294+
data_by_timestamp.sort(key=lambda d: d["timestamp"])
295+
with pytest.raises(ConditionCheckFailed,
296+
match=r"Scans do not support reverse order."):
297+
SimpleKeyModel.query(order='desc')
298+
299+
285300
def test_query_errors_with_nonprimary_key_simple(dynamo, simple_query_data):
286301
data_by_timestamp = simple_query_data[:]
287302
data_by_timestamp.sort(key=lambda d: d["timestamp"])
@@ -347,13 +362,55 @@ def test_query_with_hash_key_complex(dynamo, complex_query_data):
347362
res_data = {(m.account, m.sort_date_key): m.dict() for m in res}
348363
assert res_data == {(record["account"], record["sort_date_key"]): record}
349364

350-
# Check that it works regardless of order
365+
# Check that it works regardless of query attribute order
351366
res = ComplexKeyModel.query(
352367
Rule(f"sort_date_key == '{record['sort_date_key']}' and account == '{record['account']}'"))
353368
res_data = {(m.account, m.sort_date_key): m.dict() for m in res}
354369
assert res_data == {(record["account"], record["sort_date_key"]): record}
355370

356371

372+
@pytest.mark.parametrize('order', ('asc', 'desc'))
373+
def test_ordered_query_with_hash_key_complex(dynamo, complex_query_data, order):
374+
middle_record = complex_query_data[(len(complex_query_data)//2)]
375+
res = ComplexKeyModel.query(
376+
Rule(f"account == '{middle_record['account']}' and sort_date_key >= '{middle_record['sort_date_key']}'"),
377+
order=order
378+
)
379+
res_data = [(m.account, m.sort_date_key) for m in res]
380+
check_data = sorted([
381+
(m["account"], m["sort_date_key"])
382+
for m in complex_query_data
383+
if m["account"] == middle_record['account'] and m["sort_date_key"] >= middle_record['sort_date_key']
384+
], reverse=order == 'desc')
385+
386+
assert res_data == check_data
387+
388+
389+
@pytest.mark.parametrize('order', ('asc', 'desc'))
390+
def test_pagination_query_with_hash_key_complex(dynamo, complex_query_data, order):
391+
page_size = 2
392+
middle_record = complex_query_data[(len(complex_query_data)//2)]
393+
query_rule = Rule(f"account == '{middle_record['account']}' and sort_date_key >= '{middle_record['sort_date_key']}'")
394+
res = ComplexKeyModel.query(query_rule, order=order, limit=page_size)
395+
res_data = [(m.account, m.sort_date_key) for m in res]
396+
check_data = sorted([
397+
(m["account"], m["sort_date_key"])
398+
for m in complex_query_data
399+
if m["account"] == middle_record['account'] and m["sort_date_key"] >= middle_record['sort_date_key']
400+
], reverse=order == 'desc')[:page_size]
401+
assert res_data == check_data
402+
assert res.last_evaluated_key == check_data[-1]
403+
404+
res = ComplexKeyModel.query(query_rule, order=order, limit=page_size, exclusive_start_key=res.last_evaluated_key)
405+
res_data = [(m.account, m.sort_date_key) for m in res]
406+
check_data = sorted([
407+
(m["account"], m["sort_date_key"])
408+
for m in complex_query_data
409+
if m["account"] == middle_record['account'] and m["sort_date_key"] >= middle_record['sort_date_key']
410+
], reverse=order == 'desc')[page_size:page_size*2]
411+
assert res_data == check_data
412+
413+
357414
def test_query_errors_with_nonprimary_key_complex(dynamo, complex_query_data):
358415
data_by_expires = complex_query_data[:]
359416
data_by_expires.sort(key=lambda d: d["expires"])
@@ -381,15 +438,13 @@ def test_query_scan_complex(dynamo, complex_query_data):
381438

382439

383440
def test_query_with_nested_model(dynamo, nested_query_data):
384-
data_by_expires = nested_model_data_generator()
385-
res = NestedModel.query(filter_expr=Rule(f"expires <= '{data_by_expires['expires']}'"))
441+
res = NestedModel.query()
386442
res_data = [m.ticket for m in res]
387443
assert any(elem is not None for elem in res_data)
388444

389445

390-
def test_query_with_nested_model_optional(dynamo, nested_query_data_optional):
391-
data_by_expires = nested_model_data_generator(include_ticket=False)
392-
res = NestedModel.query(filter_expr=Rule(f"expires <= '{data_by_expires['expires']}'"))
446+
def test_query_with_nested_model_optional(dynamo, nested_query_data_empty_ticket):
447+
res = NestedModel.query()
393448
res_data = [m.ticket for m in res]
394449
assert any(elem is None for elem in res_data)
395450

@@ -404,6 +459,7 @@ def test_query_alias_save(dynamo):
404459
except Exception as e:
405460
raise pytest.fail("Failed to save Alias model!")
406461

462+
407463
def test_get_alias_model_data(dynamo, alias_query_data):
408464
data = alias_model_data_generator()
409465
res = AliasKeyModel.get(alias_query_data[0]['name'])

0 commit comments

Comments
 (0)