Skip to content

Commit 73b70db

Browse files
authored
Merge pull request #11 from RSS-Engineering/support-pagination-dynamodb
Support pagination overlaying indexes in DynamoDB
2 parents fb68cdf + c11d1a7 commit 73b70db

File tree

3 files changed

+75
-51
lines changed

3 files changed

+75
-51
lines changed

pydanticrud/backends/dynamodb.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Optional, Set, Any, Dict
1+
from typing import Optional, Set
22
import logging
33
import json
44
from datetime import datetime
5+
from base64 import b64encode, b64decode
56

67
import boto3
78
from boto3.dynamodb.conditions import Key, Attr
@@ -117,48 +118,9 @@ def index_definition(index_name, keys, gsi=False):
117118
return schema
118119

119120

120-
class DynamoIterableResult(IterableResult):
121-
def __init__(self, cls, result, serialized_items):
122-
super(DynamoIterableResult, self).__init__(cls, serialized_items, result.get("Count"))
123-
124-
self.last_evaluated_key = None
125-
lsk = result.get("LastEvaluatedKey")
126-
if lsk:
127-
_key = [lsk[cls.Config.hash_key]]
128-
if cls.Config.range_key:
129-
_key.append(lsk[cls.Config.range_key])
130-
self.last_evaluated_key = tuple(_key)
131-
132-
self.scanned_count = result["ScannedCount"]
133-
134-
135-
class Backend:
136-
def __init__(self, cls):
137-
cfg = cls.Config
138-
self.cls = cls
139-
self.schema = cls.schema()
140-
self.hash_key = cfg.hash_key
141-
self.range_key = getattr(cfg, 'range_key', None)
142-
self.table_name = cls.get_table_name()
143-
144-
self.local_indexes = getattr(cfg, "local_indexes", {})
145-
self.global_indexes = getattr(cfg, "global_indexes", {})
146-
self.index_map = {(self.hash_key,): None}
147-
self.possible_keys = {self.hash_key}
148-
if self.range_key:
149-
self.possible_keys.add(self.range_key)
150-
self.index_map = {(self.hash_key, self.range_key): None}
151-
152-
for name, keys in dict(**self.local_indexes, **self.global_indexes).items():
153-
self.index_map[keys] = name
154-
for key in keys:
155-
self.possible_keys.add(key)
156-
157-
self.dynamodb = boto3.resource(
158-
"dynamodb",
159-
region_name=getattr(cfg, "region", "us-east-2"),
160-
endpoint_url=getattr(cfg, "endpoint", None),
161-
)
121+
class DynamoSerializer:
122+
def __init__(self, schema):
123+
self.schema = schema
162124

163125
def _serialize_field(self, field_name, value):
164126
definition = self.schema.get("definitions")
@@ -175,7 +137,7 @@ def _serialize_field(self, field_name, value):
175137
log.debug(f"No serializer for field_type {field_type}")
176138
return value # do nothing but log it.
177139

178-
def _serialize_record(self, data_dict) -> dict:
140+
def serialize_record(self, data_dict) -> dict:
179141
"""
180142
Apply converters to non-native types
181143
"""
@@ -198,7 +160,7 @@ def _deserialize_field(self, field_name, value):
198160
log.debug(f"No deserializer for field_type {field_type}")
199161
return value # do nothing but log it.
200162

201-
def _deserialize_record(self, data_dict) -> dict:
163+
def deserialize_record(self, data_dict) -> dict:
202164
"""
203165
Apply converters to non-native types
204166
"""
@@ -207,6 +169,44 @@ def _deserialize_record(self, data_dict) -> dict:
207169
for field_name, value in data_dict.items()
208170
}
209171

172+
173+
class DynamoIterableResult(IterableResult):
174+
def __init__(self, cls, result, serialized_items):
175+
super(DynamoIterableResult, self).__init__(cls, serialized_items, result.get("Count"))
176+
177+
self.last_evaluated_key = result.get("LastEvaluatedKey")
178+
self.scanned_count = result["ScannedCount"]
179+
180+
181+
class Backend:
182+
def __init__(self, cls):
183+
cfg = cls.Config
184+
self.cls = cls
185+
self.schema = cls.schema()
186+
self.serializer = DynamoSerializer(self.schema)
187+
self.hash_key = cfg.hash_key
188+
self.range_key = getattr(cfg, 'range_key', None)
189+
self.table_name = cls.get_table_name()
190+
191+
self.local_indexes = getattr(cfg, "local_indexes", {})
192+
self.global_indexes = getattr(cfg, "global_indexes", {})
193+
self.index_map = {(self.hash_key,): None}
194+
self.possible_keys = {self.hash_key}
195+
if self.range_key:
196+
self.possible_keys.add(self.range_key)
197+
self.index_map = {(self.hash_key, self.range_key): None}
198+
199+
for name, keys in dict(**self.local_indexes, **self.global_indexes).items():
200+
self.index_map[keys] = name
201+
for key in keys:
202+
self.possible_keys.add(key)
203+
204+
self.dynamodb = boto3.resource(
205+
"dynamodb",
206+
region_name=getattr(cfg, "region", "us-east-2"),
207+
endpoint_url=getattr(cfg, "endpoint", None),
208+
)
209+
210210
def _key_param_to_dict(self, key):
211211
_key = {
212212
self.hash_key: key,
@@ -295,7 +295,7 @@ def query(self,
295295
query_expr: Optional[Rule] = None,
296296
filter_expr: Optional[Rule] = None,
297297
limit: Optional[int] = None,
298-
exclusive_start_key: Optional[tuple[Any]] = None,
298+
exclusive_start_key: Optional[str] = None,
299299
order: str = 'asc',
300300
):
301301
table = self.get_table()
@@ -306,7 +306,7 @@ def query(self,
306306
if limit:
307307
params["Limit"] = limit
308308
if exclusive_start_key:
309-
params["ExclusiveStartKey"] = self._key_param_to_dict(exclusive_start_key)
309+
params["ExclusiveStartKey"] = exclusive_start_key
310310
if f_expr:
311311
params["FilterExpression"] = f_expr
312312

@@ -347,7 +347,7 @@ def query(self,
347347
return []
348348
raise e
349349

350-
return DynamoIterableResult(self.cls, resp, (self._deserialize_record(rec) for rec in resp["Items"]))
350+
return DynamoIterableResult(self.cls, resp, (self.serializer.deserialize_record(rec) for rec in resp["Items"]))
351351

352352
def get(self, key):
353353
_key = self._key_param_to_dict(key)
@@ -363,10 +363,10 @@ def get(self, key):
363363
_key = key
364364
raise DoesNotExist(f'{self.table_name} "{_key}" does not exist')
365365

366-
return self._deserialize_record(resp["Item"])
366+
return self.serializer.deserialize_record(resp["Item"])
367367

368368
def save(self, item, condition: Optional[Rule] = None) -> bool:
369-
data = self._serialize_record(item.dict(by_alias=True))
369+
data = self.serializer.serialize_record(item.dict(by_alias=True))
370370

371371
try:
372372
if condition:

pydanticrud/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def __init__(self, cls, records, count=None):
1919

2020
self._current_index = 0
2121

22+
def __len__(self):
23+
return self.count
24+
2225
def __iter__(self):
2326
return self
2427

tests/test_dynamodb.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Dict, List, Optional
2+
from base64 import b64decode, b64encode
23
from decimal import Decimal
34
from datetime import datetime
5+
import json
46
from uuid import uuid4
57
import random
68

@@ -399,7 +401,7 @@ def test_pagination_query_with_hash_key_complex(dynamo, complex_query_data, orde
399401
if m["account"] == middle_record['account'] and m["sort_date_key"] >= middle_record['sort_date_key']
400402
], reverse=order == 'desc')[:page_size]
401403
assert res_data == check_data
402-
assert res.last_evaluated_key == check_data[-1]
404+
assert res.last_evaluated_key == {"account": check_data[-1][0], "sort_date_key": check_data[-1][1]}
403405

404406
res = ComplexKeyModel.query(query_rule, order=order, limit=page_size, exclusive_start_key=res.last_evaluated_key)
405407
res_data = [(m.account, m.sort_date_key) for m in res]
@@ -411,6 +413,25 @@ def test_pagination_query_with_hash_key_complex(dynamo, complex_query_data, orde
411413
assert res_data == check_data
412414

413415

416+
def test_pagination_query_with_index_complex(dynamo, complex_query_data):
417+
page_size = 2
418+
middle_record = complex_query_data[(len(complex_query_data)//2)]
419+
query_rule = Rule(f"account == '{middle_record['account']}' and category_id >= {middle_record['category_id']}")
420+
check_data = ComplexKeyModel.query(query_rule)
421+
res = ComplexKeyModel.query(query_rule, limit=page_size)
422+
res_data = [{"account": m.account, "category_id": m.category_id, "sort_date_key": m.sort_date_key} for m in res]
423+
# We only check for inclusion because the category index order is not going to be the same and since there are
424+
# multiple records per category, it's unknowable outside of the query response.
425+
assert all([r in check_data for r in res])
426+
assert len(res) == page_size
427+
assert res.last_evaluated_key == {"account": res_data[-1]["account"], "category_id": res_data[-1]["category_id"],
428+
"sort_date_key": res_data[-1]["sort_date_key"]}
429+
430+
res = ComplexKeyModel.query(query_rule, limit=page_size, exclusive_start_key=res.last_evaluated_key)
431+
assert all([r in check_data for r in res])
432+
assert len(res) == page_size
433+
434+
414435
def test_query_errors_with_nonprimary_key_complex(dynamo, complex_query_data):
415436
data_by_expires = complex_query_data[:]
416437
data_by_expires.sort(key=lambda d: d["expires"])

0 commit comments

Comments
 (0)