Skip to content

Commit 31c4bd0

Browse files
authored
fix: support pyarrow schema in DynamoDB read_items #2399 (#2401)
* fix: support pyarrow schema in DynamoDB read_items #2399 --------- Signed-off-by: Abdel Jaidi <[email protected]>
1 parent af30766 commit 31c4bd0

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

awswrangler/dynamodb/_read.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import itertools
44
import logging
5+
import warnings
56
from functools import wraps
67
from typing import (
78
TYPE_CHECKING,
@@ -164,7 +165,8 @@ def _convert_items(
164165
mapping=[
165166
{k: v.value if isinstance(v, Binary) else v for k, v in d.items()} # type: ignore[attr-defined]
166167
for d in items
167-
]
168+
],
169+
schema=arrow_kwargs.pop("schema", None),
168170
)
169171
],
170172
arrow_kwargs,
@@ -187,6 +189,7 @@ def _read_scan_chunked(
187189
dynamodb_client: Optional["DynamoDBClient"],
188190
as_dataframe: bool,
189191
kwargs: Dict[str, Any],
192+
schema: Optional[pa.Schema] = None,
190193
segment: Optional[int] = None,
191194
) -> Union[Iterator[pa.Table], Iterator[_ItemsListType]]:
192195
# SEE: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Scan.html#Scan.ParallelScan
@@ -210,7 +213,7 @@ def _read_scan_chunked(
210213
for d in response.get("Items", [])
211214
]
212215
total_items += len(items)
213-
yield _utils.list_to_arrow_table(mapping=items) if as_dataframe else items
216+
yield _utils.list_to_arrow_table(mapping=items, schema=schema) if as_dataframe else items
214217

215218
if ("Limit" in kwargs) and (total_items >= kwargs["Limit"]):
216219
break
@@ -229,13 +232,14 @@ def _read_scan(
229232
dynamodb_client: Optional["DynamoDBClient"],
230233
as_dataframe: bool,
231234
kwargs: Dict[str, Any],
235+
schema: Optional[pa.Schema],
232236
segment: int,
233237
) -> Union[pa.Table, _ItemsListType]:
234-
items_iterator: Iterator[_ItemsListType] = _read_scan_chunked(dynamodb_client, False, kwargs, segment)
238+
items_iterator: Iterator[_ItemsListType] = _read_scan_chunked(dynamodb_client, False, kwargs, None, segment)
235239

236240
items = list(itertools.chain.from_iterable(items_iterator))
237241

238-
return _utils.list_to_arrow_table(mapping=items) if as_dataframe else items
242+
return _utils.list_to_arrow_table(mapping=items, schema=schema) if as_dataframe else items
239243

240244

241245
def _read_query_chunked(
@@ -326,10 +330,11 @@ def _read_items_scan(
326330

327331
kwargs = _serialize_kwargs(kwargs)
328332
kwargs["TableName"] = table_name
333+
schema = arrow_kwargs.pop("schema", None)
329334

330335
if chunked:
331336
_logger.debug("Scanning DynamoDB table %s and returning results in an iterator", table_name)
332-
scan_iterator = _read_scan_chunked(dynamodb_client, as_dataframe, kwargs)
337+
scan_iterator = _read_scan_chunked(dynamodb_client, as_dataframe, kwargs, schema)
333338
if as_dataframe:
334339
return (_utils.table_refs_to_df([items], arrow_kwargs) for items in scan_iterator)
335340

@@ -347,6 +352,7 @@ def _read_items_scan(
347352
dynamodb_client,
348353
itertools.repeat(as_dataframe),
349354
itertools.repeat(kwargs),
355+
itertools.repeat(schema),
350356
range(total_segments),
351357
)
352358

@@ -400,6 +406,10 @@ def _read_items(
400406
items = _read_query(table_name, chunked, boto3_session, **kwargs)
401407
else:
402408
# Last resort use Scan
409+
warnings.warn(
410+
f"Attempting DynamoDB Scan operation with arguments:\n{kwargs}",
411+
UserWarning,
412+
)
403413
return _read_items_scan(
404414
table_name=table_name,
405415
as_dataframe=as_dataframe,
@@ -450,6 +460,11 @@ def read_items( # pylint: disable=too-many-branches
450460
Under the hood, it wraps all the four available read actions: `get_item`, `batch_get_item`,
451461
`query` and `scan`.
452462
463+
Warning
464+
-------
465+
To avoid a potentially costly Scan operation, please make sure to pass arguments such as
466+
`partition_values` or `max_items_evaluated`. Note that `filter_expression` is applied AFTER a Scan operation.
467+
453468
Note
454469
----
455470
Number of Parallel Scan segments is based on the `use_threads` argument.
@@ -581,6 +596,7 @@ def read_items( # pylint: disable=too-many-branches
581596
... )
582597
583598
Reading items matching a FilterExpression expressed with boto3.dynamodb.conditions.Attr
599+
Note that FilterExpression is applied AFTER a Scan operation
584600
585601
>>> import awswrangler as wr
586602
>>> from boto3.dynamodb.conditions import Attr

tests/unit/test_dynamodb.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
from decimal import Decimal
55
from typing import Any, Dict
66

7+
import pyarrow as pa
78
import pytest
89
from boto3.dynamodb.conditions import Attr, Key
910
from botocore.exceptions import ClientError
1011

1112
import awswrangler as wr
1213
import awswrangler.pandas as pd
1314

15+
from .._utils import is_ray_modin
16+
1417
pytestmark = pytest.mark.distributed
1518

1619

@@ -500,3 +503,37 @@ def test_read_items_limited(
500503
if chunked:
501504
df3 = pd.concat(df3)
502505
assert df3.shape == (min(max_items_evaluated, len(df)), len(df.columns))
506+
507+
508+
@pytest.mark.parametrize(
    "params",
    [
        {
            "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
            "AttributeDefinitions": [{"AttributeName": "id", "AttributeType": "N"}],
        }
    ],
)
@pytest.mark.parametrize("chunked", [False, True])
def test_read_items_schema(params, dynamodb_table: str, chunked: bool):
    """Check that a user-supplied pyarrow schema is honoured by Scan-based ``read_items``.

    The table is populated with ``Decimal`` ids of differing precision/scale.
    Without an explicit schema the read is expected to raise ``pa.ArrowInvalid``
    (only checked when not running on Ray/Modin). With the schema passed through
    ``pyarrow_additional_kwargs`` both a full scan and a filtered scan must succeed.

    Parameters
    ----------
    params
        Table definition consumed by the ``dynamodb_table`` fixture.
    dynamodb_table
        Name of the temporary DynamoDB table provided by the fixture.
    chunked
        Whether ``read_items`` is asked to return an iterator of frames
        instead of a single frame.
    """
    df = pd.DataFrame(
        {
            "id": [Decimal("123.4"), Decimal("226.49"), Decimal("320.320"), Decimal("425.0753")],
            "word": ["this", "is", "a", "test"],
            "char_count": [4, 2, 1, 4],
        }
    )
    wr.dynamodb.put_df(df=df, table_name=dynamodb_table)

    if not is_ray_modin:
        with pytest.raises(pa.ArrowInvalid):
            wr.dynamodb.read_items(table_name=dynamodb_table, allow_full_scan=True)

    schema = pa.schema([("id", pa.decimal128(7, 4)), ("word", pa.string()), ("char_count", pa.int8())])

    def _read(**extra: Any):
        # Build the keyword arguments from scratch for every call: the reader
        # pops ``schema`` out of its arrow kwargs, so a shared dict could let
        # the first call influence the second one.
        result = wr.dynamodb.read_items(
            table_name=dynamodb_table,
            chunked=chunked,
            pyarrow_additional_kwargs={"schema": schema},
            **extra,
        )
        # In chunked mode the result is a lazy iterator; nothing is scanned
        # (and the schema is never applied) until it is consumed.
        return pd.concat(result) if chunked else result

    df_scan = _read(allow_full_scan=True)
    assert df_scan.shape == df.shape

    # No stored id equals 1, so the filtered scan must come back empty.
    df_filtered = _read(filter_expression=Attr("id").eq(1))
    assert len(df_filtered) == 0

0 commit comments

Comments
 (0)