Skip to content

Commit 87d6396

Browse files
committed
Add boto3 session serializer/deserializer on _utils.py.
1 parent f029a3c commit 87d6396

File tree

2 files changed

+52
-15
lines changed

2 files changed

+52
-15
lines changed

awswrangler/_utils.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Internal (private) Utilities Module."""
22

3+
import copy
34
import logging
45
import math
56
import os
67
import random
7-
from typing import Any, Dict, Generator, List, Optional, Tuple
8+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
89

910
import boto3 # type: ignore
1011
import botocore.config # type: ignore
@@ -17,8 +18,10 @@
1718
_logger: logging.Logger = logging.getLogger(__name__)
1819

1920

20-
def ensure_session(session: Optional[boto3.Session] = None) -> boto3.Session:
21+
def ensure_session(session: Optional[Union[boto3.Session, Dict[str, Optional[str]]]] = None) -> boto3.Session:
2122
"""Ensure that a valid boto3.Session will be returned."""
23+
if isinstance(session, dict): # Primitives received
24+
return boto3_from_primitives(primitives=session)
2225
if session is not None:
2326
return session
2427
# Ensure the boto3's default session is used so that its parameters can be
@@ -28,6 +31,30 @@ def ensure_session(session: Optional[boto3.Session] = None) -> boto3.Session:
2831
return boto3.Session() # pragma: no cover
2932

3033

34+
def boto3_to_primitives(boto3_session: Optional[boto3.Session] = None) -> Dict[str, Optional[str]]:
    """Convert Boto3 Session to Python primitives.

    The result is a flat dictionary of strings (or ``None``) whose keys match
    the keyword arguments accepted by ``boto3.Session``.

    Parameters
    ----------
    boto3_session : boto3.Session, optional
        Session to convert. ``None`` is resolved through ``ensure_session``.

    Returns
    -------
    Dict[str, Optional[str]]
        Dictionary with the keys ``aws_access_key_id``, ``aws_secret_access_key``,
        ``aws_session_token``, ``region_name`` and ``profile_name``.
        The three credential entries are ``None`` when the session resolves no credentials.
    """
    _boto3_session: boto3.Session = ensure_session(session=boto3_session)
    credentials = _boto3_session.get_credentials()
    # Read the credentials through a single frozen snapshot: refreshable credentials
    # may rotate between individual attribute reads, which could pair an access key
    # with a secret/token from a different refresh.
    frozen = credentials.get_frozen_credentials() if credentials is not None else None
    # NOTE: the returned dictionary carries secrets in plain text; never log it.
    return {
        "aws_access_key_id": getattr(frozen, "access_key", None),
        "aws_secret_access_key": getattr(frozen, "secret_key", None),
        "aws_session_token": getattr(frozen, "token", None),
        "region_name": _boto3_session.region_name,
        "profile_name": _boto3_session.profile_name,
    }
45+
46+
47+
def boto3_from_primitives(primitives: Optional[Dict[str, Optional[str]]] = None) -> boto3.Session:
    """Convert Python primitives to Boto3 Session.

    Parameters
    ----------
    primitives : Dict[str, Optional[str]], optional
        Keyword arguments for ``boto3.Session`` (the shape produced by
        ``boto3_to_primitives``). Entries whose value is ``None`` are ignored.
        The input dictionary is never modified.

    Returns
    -------
    boto3.Session
        A session built from ``primitives``, or the session returned by
        ``ensure_session()`` when ``primitives`` is ``None``.
    """
    if primitives is None:
        # boto3.DEFAULT_SESSION is None until a default session has been set up,
        # so delegate to ensure_session() to always hand back a usable session.
        return ensure_session()  # pragma: no cover
    # Forward only the values that are actually set. A "default" profile name is
    # treated like a missing one and is not forwarded.
    # NOTE(review): presumably so that a session without an explicit profile can be
    # rebuilt where no "default" profile is configured -- confirm.
    args: Dict[str, str] = {
        key: value
        for key, value in primitives.items()
        if value is not None and not (key == "profile_name" and value == "default")
    }
    return boto3.Session(**args)
56+
57+
3158
def client(service_name: str, session: Optional[boto3.Session] = None) -> boto3.client:
3259
"""Create a valid boto3.client."""
3360
return ensure_session(session=session).client(
@@ -139,7 +166,8 @@ def chunkify(lst: List[Any], num_chunks: int = 1, max_length: Optional[int] = No
139166

140167

141168
def get_fs(
142-
session: Optional[boto3.Session] = None, s3_additional_kwargs: Optional[Dict[str, str]] = None
169+
session: Optional[Union[boto3.Session, Dict[str, Optional[str]]]] = None,
170+
s3_additional_kwargs: Optional[Dict[str, str]] = None,
143171
) -> s3fs.S3FileSystem:
144172
"""Build a S3FileSystem from a given boto3 session."""
145173
fs: s3fs.S3FileSystem = s3fs.S3FileSystem(

awswrangler/s3.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,7 +1614,7 @@ def _read_text(
16141614
path_root=path_root,
16151615
)
16161616
return dfs
1617-
if (use_threads is False) or (boto3_session is not None):
1617+
if use_threads is False:
16181618
df: pd.DataFrame = pd.concat(
16191619
objs=[
16201620
_read_text_full(
@@ -1640,7 +1640,7 @@ def _read_text(
16401640
repeat(parser_func),
16411641
repeat(path_root),
16421642
paths,
1643-
repeat(None), # Boto3.Session
1643+
repeat(_utils.boto3_to_primitives(boto3_session=session)), # Boto3.Session
16441644
repeat(pandas_kwargs),
16451645
repeat(s3_additional_kwargs),
16461646
repeat(dataset),
@@ -1683,7 +1683,7 @@ def _read_text_full(
16831683
parser_func: Callable,
16841684
path_root: str,
16851685
path: str,
1686-
boto3_session: boto3.Session,
1686+
boto3_session: Union[boto3.Session, Dict[str, Optional[str]]],
16871687
pandas_args: Dict[str, Any],
16881688
s3_additional_kwargs: Optional[Dict[str, str]] = None,
16891689
dataset: bool = False,
@@ -2354,29 +2354,38 @@ def _wait_objects(
23542354
delay = 5 if delay is None else delay
23552355
max_attempts = 20 if max_attempts is None else max_attempts
23562356
_delay: int = int(delay) if isinstance(delay, float) else delay
2357-
23582357
if len(paths) < 1:
23592358
return None
23602359
client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)
2361-
waiter = client_s3.get_waiter(waiter_name)
23622360
_paths: List[Tuple[str, str]] = [_utils.parse_path(path=p) for p in paths]
23632361
if use_threads is False:
2362+
waiter = client_s3.get_waiter(waiter_name)
23642363
for bucket, key in _paths:
23652364
waiter.wait(Bucket=bucket, Key=key, WaiterConfig={"Delay": _delay, "MaxAttempts": max_attempts})
23662365
else:
23672366
cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
23682367
with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
2369-
futures: List[concurrent.futures.Future] = []
2370-
for bucket, key in _paths:
2371-
future: concurrent.futures.Future = executor.submit(
2372-
fn=waiter.wait, Bucket=bucket, Key=key, WaiterConfig={"Delay": _delay, "MaxAttempts": max_attempts}
2368+
list(
2369+
executor.map(
2370+
_wait_objects_concurrent,
2371+
_paths,
2372+
repeat(waiter_name),
2373+
repeat(client_s3),
2374+
repeat(_delay),
2375+
repeat(max_attempts),
23732376
)
2374-
futures.append(future)
2375-
for future in futures:
2376-
future.result()
2377+
)
23772378
return None
23782379

23792380

2381+
def _wait_objects_concurrent(
    path: Tuple[str, str], waiter_name: str, client_s3: boto3.client, delay: int, max_attempts: int
) -> None:
    """Wait on a single S3 object using the waiter named ``waiter_name``.

    Worker function mapped over the parsed paths by ``_wait_objects`` when
    running with threads. Each call requests its own waiter from ``client_s3``.

    Parameters
    ----------
    path : Tuple[str, str]
        ``(bucket, key)`` pair identifying the object.
    waiter_name : str
        Name of the waiter to request from the client.
    client_s3 : boto3.client
        S3 client used to build the waiter.
    delay : int
        Value passed as the waiter's ``Delay`` setting.
    max_attempts : int
        Value passed as the waiter's ``MaxAttempts`` setting.
    """
    object_waiter = client_s3.get_waiter(waiter_name)
    bucket_name, object_key = path
    waiter_config: Dict[str, int] = {"Delay": delay, "MaxAttempts": max_attempts}
    object_waiter.wait(Bucket=bucket_name, Key=object_key, WaiterConfig=waiter_config)
2387+
2388+
23802389
def read_parquet_table(
23812390
table: str,
23822391
database: str,

0 commit comments

Comments
 (0)