Skip to content

Commit cdefe04

Browse files
refactor: Remove usage of boto3 resources (#2525)
1 parent 28880d5 commit cdefe04

File tree

9 files changed

+480
-220
lines changed

9 files changed

+480
-220
lines changed

awswrangler/dynamodb/_delete.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from typing import Any, Dict, List, Optional
55

66
import boto3
7+
from boto3.dynamodb.types import TypeSerializer
78

9+
from awswrangler import _utils
810
from awswrangler._config import apply_configs
911

10-
from ._utils import _validate_items, get_table
12+
from ._utils import _TableBatchWriter, _validate_items
1113

1214
_logger: logging.Logger = logging.getLogger(__name__)
1315

@@ -46,9 +48,16 @@ def delete_items(
4648
"""
4749
_logger.debug("Deleting items from DynamoDB table %s", table_name)
4850

49-
dynamodb_table = get_table(table_name=table_name, boto3_session=boto3_session)
50-
_validate_items(items=items, dynamodb_table=dynamodb_table)
51-
table_keys = [schema["AttributeName"] for schema in dynamodb_table.key_schema]
52-
with dynamodb_table.batch_writer() as writer:
51+
dynamodb_client = _utils.client(service_name="dynamodb", session=boto3_session)
52+
serializer = TypeSerializer()
53+
54+
key_schema = dynamodb_client.describe_table(TableName=table_name)["Table"]["KeySchema"]
55+
_validate_items(items=items, key_schema=key_schema)
56+
57+
table_keys = [schema["AttributeName"] for schema in key_schema]
58+
59+
with _TableBatchWriter(table_name, dynamodb_client) as writer:
5360
for item in items:
54-
writer.delete_item(Key={key: item[key] for key in table_keys})
61+
writer.delete_item(
62+
key={key: serializer.serialize(item[key]) for key in table_keys},
63+
)

awswrangler/dynamodb/_read.py

Lines changed: 86 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Dict,
1212
Iterator,
1313
List,
14+
NamedTuple,
1415
Optional,
1516
Sequence,
1617
TypeVar,
@@ -20,8 +21,8 @@
2021

2122
import boto3
2223
import pyarrow as pa
23-
from boto3.dynamodb.conditions import ConditionBase
24-
from boto3.dynamodb.types import Binary
24+
from boto3.dynamodb.conditions import ConditionBase, ConditionExpressionBuilder
25+
from boto3.dynamodb.types import Binary, TypeDeserializer, TypeSerializer
2526
from botocore.exceptions import ClientError
2627
from typing_extensions import Literal
2728

@@ -30,7 +31,7 @@
3031
from awswrangler._distributed import engine
3132
from awswrangler._executor import _BaseExecutor, _get_executor
3233
from awswrangler.distributed.ray import ray_get
33-
from awswrangler.dynamodb._utils import _serialize_kwargs, execute_statement, get_table
34+
from awswrangler.dynamodb._utils import _deserialize_item, _serialize_item, execute_statement
3435

3536
if TYPE_CHECKING:
3637
from mypy_boto3_dynamodb.client import DynamoDBClient
@@ -195,8 +196,8 @@ def _read_scan_chunked(
195196
# SEE: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Scan.html#Scan.ParallelScan
196197
client_dynamodb = dynamodb_client if dynamodb_client else _utils.client(service_name="dynamodb")
197198

198-
deserializer = boto3.dynamodb.types.TypeDeserializer()
199-
next_token = "init_token" # Dummy token
199+
deserializer = TypeDeserializer()
200+
next_token: Optional[str] = "init_token" # Dummy token
200201
total_items = 0
201202

202203
kwargs = dict(kwargs)
@@ -218,7 +219,7 @@ def _read_scan_chunked(
218219
if ("Limit" in kwargs) and (total_items >= kwargs["Limit"]):
219220
break
220221

221-
next_token = response.get("LastEvaluatedKey", None) # type: ignore[assignment]
222+
next_token = response.get("LastEvaluatedKey", None)
222223
if next_token:
223224
kwargs["ExclusiveStartKey"] = next_token
224225

@@ -242,33 +243,30 @@ def _read_scan(
242243
return _utils.list_to_arrow_table(mapping=items, schema=schema) if as_dataframe else items
243244

244245

245-
def _read_query_chunked(
246-
table_name: str, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
247-
) -> Iterator[_ItemsListType]:
248-
table = get_table(table_name=table_name, boto3_session=boto3_session)
249-
next_token = "init_token" # Dummy token
246+
def _read_query_chunked(table_name: str, dynamodb_client: "DynamoDBClient", **kwargs: Any) -> Iterator[_ItemsListType]:
247+
next_token: Optional[str] = "init_token" # Dummy token
250248
total_items = 0
251249

252250
# Handle pagination
253251
while next_token:
254-
response = table.query(**kwargs)
252+
response = dynamodb_client.query(TableName=table_name, **kwargs)
255253
items = response.get("Items", [])
256254
total_items += len(items)
257255
yield items
258256

259257
if ("Limit" in kwargs) and (total_items >= kwargs["Limit"]):
260258
break
261259

262-
next_token = response.get("LastEvaluatedKey", None) # type: ignore[assignment]
260+
next_token = response.get("LastEvaluatedKey", None)
263261
if next_token:
264262
kwargs["ExclusiveStartKey"] = next_token
265263

266264

267265
@_handle_reserved_keyword_error
268266
def _read_query(
269-
table_name: str, chunked: bool, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
267+
table_name: str, dynamodb_client: "DynamoDBClient", chunked: bool, **kwargs: Any
270268
) -> Union[_ItemsListType, Iterator[_ItemsListType]]:
271-
items_iterator = _read_query_chunked(table_name, boto3_session, **kwargs)
269+
items_iterator = _read_query_chunked(table_name, dynamodb_client, **kwargs)
272270

273271
if chunked:
274272
return items_iterator
@@ -277,12 +275,13 @@ def _read_query(
277275

278276

279277
def _read_batch_items_chunked(
280-
table_name: str, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
278+
table_name: str, dynamodb_client: Optional["DynamoDBClient"], **kwargs: Any
281279
) -> Iterator[_ItemsListType]:
282-
resource = _utils.resource(service_name="dynamodb", session=boto3_session)
280+
dynamodb_client = dynamodb_client if dynamodb_client else _utils.client("dynamodb")
281+
deserializer = TypeDeserializer()
283282

284-
response = resource.batch_get_item(RequestItems={table_name: kwargs}) # type: ignore[dict-item]
285-
yield response.get("Responses", {table_name: []}).get(table_name, []) # type: ignore[arg-type]
283+
response = dynamodb_client.batch_get_item(RequestItems={table_name: kwargs})
284+
yield [_deserialize_item(d, deserializer) for d in response.get("Responses", {table_name: []}).get(table_name, [])]
286285

287286
# SEE: handle possible unprocessed keys. As suggested in Boto3 docs,
288287
# this approach should involve exponential backoff, but this should be
@@ -291,15 +290,17 @@ def _read_batch_items_chunked(
291290
while response["UnprocessedKeys"]:
292291
kwargs["Keys"] = response["UnprocessedKeys"][table_name]["Keys"]
293292

294-
response = resource.batch_get_item(RequestItems={table_name: kwargs}) # type: ignore[dict-item]
295-
yield response.get("Responses", {table_name: []}).get(table_name, []) # type: ignore[arg-type]
293+
response = dynamodb_client.batch_get_item(RequestItems={table_name: kwargs})
294+
yield [
295+
_deserialize_item(d, deserializer) for d in response.get("Responses", {table_name: []}).get(table_name, [])
296+
]
296297

297298

298299
@_handle_reserved_keyword_error
299300
def _read_batch_items(
300-
table_name: str, chunked: bool, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
301+
table_name: str, dynamodb_client: Optional["DynamoDBClient"], chunked: bool, **kwargs: Any
301302
) -> Union[_ItemsListType, Iterator[_ItemsListType]]:
302-
items_iterator = _read_batch_items_chunked(table_name, boto3_session, **kwargs)
303+
items_iterator = _read_batch_items_chunked(table_name, dynamodb_client, **kwargs)
303304

304305
if chunked:
305306
return items_iterator
@@ -309,10 +310,13 @@ def _read_batch_items(
309310

310311
@_handle_reserved_keyword_error
311312
def _read_item(
312-
table_name: str, chunked: bool = False, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
313+
table_name: str,
314+
dynamodb_client: "DynamoDBClient",
315+
chunked: bool = False,
316+
**kwargs: Any,
313317
) -> Union[_ItemsListType, Iterator[_ItemsListType]]:
314-
table = get_table(table_name=table_name, boto3_session=boto3_session)
315-
item_list: _ItemsListType = [table.get_item(**kwargs).get("Item", {})]
318+
item = dynamodb_client.get_item(TableName=table_name, **kwargs).get("Item", {})
319+
item_list: _ItemsListType = [_deserialize_item(item)]
316320

317321
return [item_list] if chunked else item_list
318322

@@ -322,13 +326,10 @@ def _read_items_scan(
322326
as_dataframe: bool,
323327
arrow_kwargs: Dict[str, Any],
324328
use_threads: Union[bool, int],
329+
dynamodb_client: "DynamoDBClient",
325330
chunked: bool,
326-
boto3_session: Optional[boto3.Session] = None,
327331
**kwargs: Any,
328332
) -> Union[pd.DataFrame, Iterator[pd.DataFrame], _ItemsListType, Iterator[_ItemsListType]]:
329-
dynamodb_client = _utils.client(service_name="dynamodb", session=boto3_session)
330-
331-
kwargs = _serialize_kwargs(kwargs)
332333
kwargs["TableName"] = table_name
333334
schema = arrow_kwargs.pop("schema", None)
334335

@@ -368,7 +369,7 @@ def _read_items(
368369
arrow_kwargs: Dict[str, Any],
369370
use_threads: Union[bool, int],
370371
chunked: bool,
371-
boto3_session: Optional[boto3.Session] = None,
372+
dynamodb_client: "DynamoDBClient",
372373
**kwargs: Any,
373374
) -> Union[pd.DataFrame, Iterator[pd.DataFrame], _ItemsListType, Iterator[_ItemsListType]]:
374375
# Extract 'Keys', 'IndexName' and 'Limit' from provided kwargs: if needed, will be reinserted later on
@@ -384,12 +385,12 @@ def _read_items(
384385
# Single Item
385386
if use_get_item:
386387
kwargs["Key"] = keys[0]
387-
items = _read_item(table_name, chunked, boto3_session, **kwargs)
388+
items = _read_item(table_name, dynamodb_client, chunked, **kwargs)
388389

389390
# Batch of Items
390391
elif use_batch_get_item:
391392
kwargs["Keys"] = keys
392-
items = _read_batch_items(table_name, chunked, boto3_session, **kwargs)
393+
items = _read_batch_items(table_name, dynamodb_client, chunked, **kwargs)
393394

394395
else:
395396
if limit:
@@ -403,7 +404,7 @@ def _read_items(
403404
if use_query:
404405
# Query
405406
_logger.debug("Query DynamoDB table %s", table_name)
406-
items = _read_query(table_name, chunked, boto3_session, **kwargs)
407+
items = _read_query(table_name, dynamodb_client, chunked, **kwargs)
407408
else:
408409
# Last resort use Scan
409410
warnings.warn(
@@ -415,8 +416,8 @@ def _read_items(
415416
as_dataframe=as_dataframe,
416417
arrow_kwargs=arrow_kwargs,
417418
use_threads=use_threads,
419+
dynamodb_client=dynamodb_client,
418420
chunked=chunked,
419-
boto3_session=boto3_session,
420421
**kwargs,
421422
)
422423

@@ -428,6 +429,25 @@ def _read_items(
428429
return _convert_items(items=cast(_ItemsListType, items), as_dataframe=as_dataframe, arrow_kwargs=arrow_kwargs)
429430

430431

432+
class _ExpressionTuple(NamedTuple):
433+
condition_expression: str
434+
attribute_name_placeholders: Dict[str, str]
435+
attribute_value_placeholders: Dict[str, Any]
436+
437+
438+
def _convert_condition_base_to_expression(
439+
key_condition_expression: ConditionBase, is_key_condition: bool, serializer: TypeSerializer
440+
) -> Dict[str, Any]:
441+
builder = ConditionExpressionBuilder()
442+
expression = builder.build_expression(key_condition_expression, is_key_condition=is_key_condition)
443+
444+
return _ExpressionTuple(
445+
condition_expression=expression.condition_expression,
446+
attribute_name_placeholders=expression.attribute_name_placeholders,
447+
attribute_value_placeholders=_serialize_item(expression.attribute_value_placeholders, serializer=serializer),
448+
)
449+
450+
431451
@_utils.validate_distributed_kwargs(
432452
unsupported_kwargs=["boto3_session", "dtype_backend"],
433453
)
@@ -630,7 +650,9 @@ def read_items( # pylint: disable=too-many-branches
630650
)
631651

632652
# Extract key schema
633-
table_key_schema = get_table(table_name=table_name, boto3_session=boto3_session).key_schema
653+
dynamodb_client = _utils.client(service_name="dynamodb", session=boto3_session)
654+
serializer = TypeSerializer()
655+
table_key_schema = dynamodb_client.describe_table(TableName=table_name)["Table"]["KeySchema"]
634656

635657
# Detect sort key, if any
636658
if len(table_key_schema) == 1:
@@ -645,28 +667,50 @@ def read_items( # pylint: disable=too-many-branches
645667
kwargs: Dict[str, Any] = {"ConsistentRead": consistent}
646668
if partition_values:
647669
if sort_key is None:
648-
keys = [{partition_key: pv} for pv in partition_values]
670+
keys = [{partition_key: serializer.serialize(pv)} for pv in partition_values]
649671
else:
650672
if not sort_values:
651673
raise exceptions.InvalidArgumentType(
652674
f"Kwarg sort_values must be specified: table {table_name} has {sort_key} as sort key."
653675
)
654676
if len(sort_values) != len(partition_values):
655677
raise exceptions.InvalidArgumentCombination("Partition and sort values must have the same length.")
656-
keys = [{partition_key: pv, sort_key: sv} for pv, sv in zip(partition_values, sort_values)]
678+
keys = [
679+
{partition_key: serializer.serialize(pv), sort_key: serializer.serialize(sv)}
680+
for pv, sv in zip(partition_values, sort_values)
681+
]
657682
kwargs["Keys"] = keys
658683
if index_name:
659684
kwargs["IndexName"] = index_name
685+
660686
if key_condition_expression:
661-
kwargs["KeyConditionExpression"] = key_condition_expression
687+
if isinstance(key_condition_expression, str):
688+
kwargs["KeyConditionExpression"] = key_condition_expression
689+
else:
690+
expression_tuple = _convert_condition_base_to_expression(
691+
key_condition_expression, is_key_condition=True, serializer=serializer
692+
)
693+
kwargs["KeyConditionExpression"] = expression_tuple.condition_expression
694+
kwargs["ExpressionAttributeNames"] = expression_tuple.attribute_name_placeholders
695+
kwargs["ExpressionAttributeValues"] = expression_tuple.attribute_value_placeholders
696+
662697
if filter_expression:
663-
kwargs["FilterExpression"] = filter_expression
698+
if isinstance(filter_expression, str):
699+
kwargs["FilterExpression"] = filter_expression
700+
else:
701+
expression_tuple = _convert_condition_base_to_expression(
702+
filter_expression, is_key_condition=False, serializer=serializer
703+
)
704+
kwargs["FilterExpression"] = expression_tuple.condition_expression
705+
kwargs["ExpressionAttributeNames"] = expression_tuple.attribute_name_placeholders
706+
kwargs["ExpressionAttributeValues"] = expression_tuple.attribute_value_placeholders
707+
664708
if columns:
665709
kwargs["ProjectionExpression"] = ", ".join(columns)
666710
if expression_attribute_names:
667711
kwargs["ExpressionAttributeNames"] = expression_attribute_names
668712
if expression_attribute_values:
669-
kwargs["ExpressionAttributeValues"] = expression_attribute_values
713+
kwargs["ExpressionAttributeValues"] = _serialize_item(expression_attribute_values, serializer)
670714
if max_items_evaluated:
671715
kwargs["Limit"] = max_items_evaluated
672716

@@ -678,8 +722,8 @@ def read_items( # pylint: disable=too-many-branches
678722
as_dataframe=as_dataframe,
679723
arrow_kwargs=arrow_kwargs,
680724
use_threads=use_threads,
681-
boto3_session=boto3_session,
682725
chunked=chunked,
726+
dynamodb_client=dynamodb_client,
683727
**kwargs,
684728
)
685729
# Raise otherwise

awswrangler/dynamodb/_read.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def _read_scan(
2626
dynamodb_client: Optional["DynamoDBClient"],
2727
as_dataframe: bool,
2828
kwargs: Dict[str, Any],
29+
schema: Optional[pa.Schema],
2930
segment: int,
3031
) -> Union[pa.Table, _ItemsListType]: ...
3132
@overload

0 commit comments

Comments
 (0)