Skip to content

Commit 5dd305d

Browse files
authored
Merge pull request #21 from RSS-Engineering/add_count
Add `count()` for dynamodb backends
2 parents 91839de + b6d23cb commit 5dd305d

File tree

5 files changed

+41
-2
lines changed

5 files changed

+41
-2
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ NOTE: Rule complexity is limited by the querying capabilities of the backend.
7272
- Providing a `limit` parameter will limit the number of results. If more results remain, the returned dataset will have an `last_evaluated_key` property that can be passed to `exclusive_start_key` to continue with the next page.
7373
- Providing `order='desc'` will return the result set in descending order. This is not available for query calls that "scan" dynamodb.
7474

75+
`count(query_expr: Optional[Rule], exclusive_start_key: Optional[tuple[Any]], order: str = 'asc'`
76+
- Same as `query` but returns an integer count as total. (When calling `query` with a limit, the count dynamodb returns is <= the limit you provide)
77+
78+
7579
## Backend Configuration Members
7680

7781
`hash_key` - the name of the key field for the backend table

pydanticrud/backends/dynamodb.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def query(
321321
limit: Optional[int] = None,
322322
exclusive_start_key: Optional[str] = None,
323323
order: str = "asc",
324+
select: Optional[str] = None,
324325
):
325326
table = self.get_table()
326327
f_expr, _ = rule_to_boto_expression(filter_expr) if filter_expr else (None, set())
@@ -348,6 +349,9 @@ def query(
348349
if order != "asc":
349350
params["ScanIndexForward"] = False
350351

352+
if select:
353+
params["Select"] = select
354+
351355
if index_name:
352356
params["IndexName"] = index_name
353357
elif not keys_used.issubset({self.hash_key, self.range_key}):
@@ -376,8 +380,28 @@ def query(
376380
raise e
377381

378382
return DynamoIterableResult(
379-
self.cls, resp, (self.serializer.deserialize_record(rec) for rec in resp["Items"])
383+
self.cls,
384+
resp,
385+
(self.serializer.deserialize_record(rec) for rec in resp.get("Items", [])),
386+
)
387+
388+
def count(
389+
self,
390+
query_expr: Optional[Rule] = None,
391+
exclusive_start_key: Optional[str] = None,
392+
order: str = "asc",
393+
) -> int:
394+
"""
395+
Dynamo Query returns a full "scanned_count" but when a limit is specified this count is <= the limit. To
396+
get a full count (i.e. for pagination), a limitless query must be run.
397+
"""
398+
result = self.query(
399+
query_expr=query_expr,
400+
exclusive_start_key=exclusive_start_key,
401+
order=order,
402+
select="COUNT",
380403
)
404+
return result.scanned_count
381405

382406
def get(self, key: Union[Dict, Any]):
383407
if isinstance(key, dict):

pydanticrud/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def query(cls, *args, **kwargs):
5959
res = IterableResult(cls, res)
6060
return res
6161

62+
@classmethod
63+
def count(cls, *args, **kwargs):
64+
return cls.__backend__.count(*args, **kwargs)
65+
6266
@classmethod
6367
def get(cls, *args, **kwargs):
6468
return cls.parse_obj(cls.__backend__.get(*args, **kwargs))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pydanticrud"
3-
version = "0.4.1"
3+
version = "0.4.2"
44
description = "Supercharge your Pydantic models with CRUD methods and a pluggable backend"
55
authors = ["Timothy Farrell <[email protected]>"]
66
license = "MIT"

tests/test_dynamodb.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,13 @@ def test_pagination_query_with_index_complex(dynamo, complex_query_data):
474474
assert all([r in check_data for r in res])
475475
assert len(res) == page_size
476476

477+
def test_pagination_query_count(dynamo, complex_query_data):
478+
page_size = 2
479+
middle_record = complex_query_data[(len(complex_query_data)//2)]
480+
query_rule = Rule(f"account == '{middle_record['account']}' and category_id >= {middle_record['category_id']}")
481+
check_data = ComplexKeyModel.query(query_rule)
482+
res_count = ComplexKeyModel.count(query_rule)
483+
assert res_count == check_data.scanned_count
477484

478485
def test_query_errors_with_nonprimary_key_complex(dynamo, complex_query_data):
479486
data_by_expires = complex_query_data[:]

0 commit comments

Comments
 (0)