Skip to content

Commit 54a266c

Browse files
committed
Add support for requester pays on s3. #430
1 parent 2fb23f6 commit 54a266c

File tree

14 files changed

+250
-54
lines changed

14 files changed

+250
-54
lines changed

awswrangler/athena/_read.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,17 @@ def _fix_csv_types(df: pd.DataFrame, parse_dates: List[str], binaries: List[str]
7272

7373

7474
def _delete_after_iterate(
75-
dfs: Iterator[pd.DataFrame], paths: List[str], use_threads: bool, boto3_session: boto3.Session
75+
dfs: Iterator[pd.DataFrame],
76+
paths: List[str],
77+
use_threads: bool,
78+
boto3_session: boto3.Session,
79+
s3_additional_kwargs: Optional[Dict[str, str]],
7680
) -> Iterator[pd.DataFrame]:
7781
for df in dfs:
7882
yield df
79-
s3.delete_objects(path=paths, use_threads=use_threads, boto3_session=boto3_session)
83+
s3.delete_objects(
84+
path=paths, use_threads=use_threads, boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs
85+
)
8086

8187

8288
def _prepare_query_string_for_comparison(query_string: str) -> str:
@@ -213,6 +219,7 @@ def _fetch_parquet_result(
213219
chunksize: Optional[int],
214220
use_threads: bool,
215221
boto3_session: boto3.Session,
222+
s3_additional_kwargs: Optional[Dict[str, Any]],
216223
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
217224
ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
218225
chunked: Union[bool, int] = False if chunksize is None else chunksize
@@ -242,10 +249,21 @@ def _fetch_parquet_result(
242249
_logger.debug("type(ret): %s", type(ret))
243250
if chunked is False:
244251
if keep_files is False:
245-
s3.delete_objects(path=paths_delete, use_threads=use_threads, boto3_session=boto3_session)
252+
s3.delete_objects(
253+
path=paths_delete,
254+
use_threads=use_threads,
255+
boto3_session=boto3_session,
256+
s3_additional_kwargs=s3_additional_kwargs,
257+
)
246258
return ret
247259
if keep_files is False:
248-
return _delete_after_iterate(dfs=ret, paths=paths_delete, use_threads=use_threads, boto3_session=boto3_session)
260+
return _delete_after_iterate(
261+
dfs=ret,
262+
paths=paths_delete,
263+
use_threads=use_threads,
264+
boto3_session=boto3_session,
265+
s3_additional_kwargs=s3_additional_kwargs,
266+
)
249267
return ret
250268

251269

@@ -255,6 +273,7 @@ def _fetch_csv_result(
255273
chunksize: Optional[int],
256274
use_threads: bool,
257275
boto3_session: boto3.Session,
276+
s3_additional_kwargs: Optional[Dict[str, Any]],
258277
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
259278
_chunksize: Optional[int] = chunksize if isinstance(chunksize, int) else None
260279
_logger.debug("_chunksize: %s", _chunksize)
@@ -282,13 +301,22 @@ def _fetch_csv_result(
282301
df = _fix_csv_types(df=ret, parse_dates=query_metadata.parse_dates, binaries=query_metadata.binaries)
283302
df = _apply_query_metadata(df=df, query_metadata=query_metadata)
284303
if keep_files is False:
285-
s3.delete_objects(path=[path, f"{path}.metadata"], use_threads=use_threads, boto3_session=boto3_session)
304+
s3.delete_objects(
305+
path=[path, f"{path}.metadata"],
306+
use_threads=use_threads,
307+
boto3_session=boto3_session,
308+
s3_additional_kwargs=s3_additional_kwargs,
309+
)
286310
return df
287311
dfs = _fix_csv_types_generator(dfs=ret, parse_dates=query_metadata.parse_dates, binaries=query_metadata.binaries)
288312
dfs = _add_query_metadata_generator(dfs=dfs, query_metadata=query_metadata)
289313
if keep_files is False:
290314
return _delete_after_iterate(
291-
dfs=dfs, paths=[path, f"{path}.metadata"], use_threads=use_threads, boto3_session=boto3_session
315+
dfs=dfs,
316+
paths=[path, f"{path}.metadata"],
317+
use_threads=use_threads,
318+
boto3_session=boto3_session,
319+
s3_additional_kwargs=s3_additional_kwargs,
292320
)
293321
return dfs
294322

@@ -299,6 +327,7 @@ def _resolve_query_with_cache(
299327
chunksize: Optional[Union[int, bool]],
300328
use_threads: bool,
301329
session: Optional[boto3.Session],
330+
s3_additional_kwargs: Optional[Dict[str, Any]],
302331
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
303332
"""Fetch cached data and return it as a pandas DataFrame (or list of DataFrames)."""
304333
_logger.debug("cache_info:\n%s", cache_info)
@@ -319,6 +348,7 @@ def _resolve_query_with_cache(
319348
chunksize=chunksize,
320349
use_threads=use_threads,
321350
boto3_session=session,
351+
s3_additional_kwargs=s3_additional_kwargs,
322352
)
323353
if cache_info.file_format == "csv":
324354
return _fetch_csv_result(
@@ -327,6 +357,7 @@ def _resolve_query_with_cache(
327357
chunksize=chunksize,
328358
use_threads=use_threads,
329359
boto3_session=session,
360+
s3_additional_kwargs=s3_additional_kwargs,
330361
)
331362
raise exceptions.InvalidArgumentValue(f"Invalid data type: {cache_info.file_format}.")
332363

@@ -345,6 +376,7 @@ def _resolve_query_without_cache_ctas(
345376
wg_config: _WorkGroupConfig,
346377
name: Optional[str],
347378
use_threads: bool,
379+
s3_additional_kwargs: Optional[Dict[str, Any]],
348380
boto3_session: boto3.Session,
349381
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
350382
path: str = f"{s3_output}/{name}"
@@ -412,6 +444,7 @@ def _resolve_query_without_cache_ctas(
412444
categories=categories,
413445
chunksize=chunksize,
414446
use_threads=use_threads,
447+
s3_additional_kwargs=s3_additional_kwargs,
415448
boto3_session=boto3_session,
416449
)
417450

@@ -429,6 +462,7 @@ def _resolve_query_without_cache_regular(
429462
kms_key: Optional[str],
430463
wg_config: _WorkGroupConfig,
431464
use_threads: bool,
465+
s3_additional_kwargs: Optional[Dict[str, Any]],
432466
boto3_session: boto3.Session,
433467
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
434468
_logger.debug("sql: %s", sql)
@@ -456,6 +490,7 @@ def _resolve_query_without_cache_regular(
456490
chunksize=chunksize,
457491
use_threads=use_threads,
458492
boto3_session=boto3_session,
493+
s3_additional_kwargs=s3_additional_kwargs,
459494
)
460495

461496

@@ -474,6 +509,7 @@ def _resolve_query_without_cache(
474509
keep_files: bool,
475510
ctas_temp_table_name: Optional[str],
476511
use_threads: bool,
512+
s3_additional_kwargs: Optional[Dict[str, Any]],
477513
boto3_session: boto3.Session,
478514
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
479515
"""
@@ -504,6 +540,7 @@ def _resolve_query_without_cache(
504540
wg_config=wg_config,
505541
name=name,
506542
use_threads=use_threads,
543+
s3_additional_kwargs=s3_additional_kwargs,
507544
boto3_session=boto3_session,
508545
)
509546
finally:
@@ -521,6 +558,7 @@ def _resolve_query_without_cache(
521558
kms_key=kms_key,
522559
wg_config=wg_config,
523560
use_threads=use_threads,
561+
s3_additional_kwargs=s3_additional_kwargs,
524562
boto3_session=boto3_session,
525563
)
526564

@@ -546,6 +584,7 @@ def read_sql_query(
546584
max_local_cache_entries: int = 100,
547585
data_source: Optional[str] = None,
548586
params: Optional[Dict[str, Any]] = None,
587+
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
549588
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
550589
"""Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame.
551590
@@ -705,6 +744,9 @@ def read_sql_query(
705744
Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported.
706745
The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain
707746
`:name;`.
747+
s3_additional_kwargs : Optional[Dict[str, Any]]
748+
Forward to botocore requests. Valid parameters: "RequestPayer".
749+
e.g. s3_additional_kwargs={'RequestPayer': 'requester'}
708750
709751
Returns
710752
-------
@@ -761,6 +803,7 @@ def read_sql_query(
761803
chunksize=chunksize,
762804
use_threads=use_threads,
763805
session=session,
806+
s3_additional_kwargs=s3_additional_kwargs,
764807
)
765808
except Exception as e: # pylint: disable=broad-except
766809
_logger.error(e) # if there is anything wrong with the cache, just fallback to the usual path
@@ -779,6 +822,7 @@ def read_sql_query(
779822
keep_files=keep_files,
780823
ctas_temp_table_name=ctas_temp_table_name,
781824
use_threads=use_threads,
825+
s3_additional_kwargs=s3_additional_kwargs,
782826
boto3_session=session,
783827
)
784828

@@ -803,6 +847,7 @@ def read_sql_table(
803847
max_remote_cache_entries: int = 50,
804848
max_local_cache_entries: int = 100,
805849
data_source: Optional[str] = None,
850+
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
806851
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
807852
"""Extract the full table from AWS Athena and return the results as a Pandas DataFrame.
808853
@@ -953,6 +998,9 @@ def read_sql_table(
953998
Only takes effect if max_cache_seconds > 0 and default value is 100.
954999
data_source : str, optional
9551000
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
1001+
s3_additional_kwargs : Optional[Dict[str, Any]]
1002+
Forward to botocore requests. Valid parameters: "RequestPayer".
1003+
e.g. s3_additional_kwargs={'RequestPayer': 'requester'}
9561004
9571005
Returns
9581006
-------
@@ -986,6 +1034,7 @@ def read_sql_table(
9861034
max_cache_query_inspections=max_cache_query_inspections,
9871035
max_remote_cache_entries=max_remote_cache_entries,
9881036
max_local_cache_entries=max_local_cache_entries,
1037+
s3_additional_kwargs=s3_additional_kwargs,
9891038
)
9901039

9911040

awswrangler/athena/_utils.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def _fetch_txt_result(
191191
query_metadata: _QueryMetadata,
192192
keep_files: bool,
193193
boto3_session: boto3.Session,
194+
s3_additional_kwargs: Optional[Dict[str, str]],
194195
) -> pd.DataFrame:
195196
if query_metadata.output_location is None or query_metadata.output_location.endswith(".txt") is False:
196197
return pd.DataFrame()
@@ -211,7 +212,12 @@ def _fetch_txt_result(
211212
sep="\t",
212213
)
213214
if keep_files is False:
214-
s3.delete_objects(path=[path, f"{path}.metadata"], use_threads=False, boto3_session=boto3_session)
215+
s3.delete_objects(
216+
path=[path, f"{path}.metadata"],
217+
use_threads=False,
218+
boto3_session=boto3_session,
219+
s3_additional_kwargs=s3_additional_kwargs,
220+
)
215221
return df
216222

217223

@@ -532,6 +538,7 @@ def describe_table(
532538
workgroup: Optional[str] = None,
533539
encryption: Optional[str] = None,
534540
kms_key: Optional[str] = None,
541+
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
535542
boto3_session: Optional[boto3.Session] = None,
536543
) -> pd.DataFrame:
537544
"""Show the list of columns, including partition columns: 'DESCRIBE table;'.
@@ -558,6 +565,9 @@ def describe_table(
558565
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
559566
kms_key : str, optional
560567
For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
568+
s3_additional_kwargs : Optional[Dict[str, Any]]
569+
Forward to botocore requests. Valid parameters: "RequestPayer".
570+
e.g. s3_additional_kwargs={'RequestPayer': 'requester'}
561571
boto3_session : boto3.Session(), optional
562572
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
563573
@@ -587,9 +597,7 @@ def describe_table(
587597
)
588598
query_metadata: _QueryMetadata = _get_query_metadata(query_execution_id=query_id, boto3_session=session)
589599
raw_result = _fetch_txt_result(
590-
query_metadata=query_metadata,
591-
keep_files=True,
592-
boto3_session=session,
600+
query_metadata=query_metadata, keep_files=True, boto3_session=session, s3_additional_kwargs=s3_additional_kwargs
593601
)
594602
return _parse_describe_table(raw_result)
595603

@@ -602,6 +610,7 @@ def show_create_table(
602610
workgroup: Optional[str] = None,
603611
encryption: Optional[str] = None,
604612
kms_key: Optional[str] = None,
613+
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
605614
boto3_session: Optional[boto3.Session] = None,
606615
) -> str:
607616
"""Generate the query that created it: 'SHOW CREATE TABLE table;'.
@@ -627,6 +636,9 @@ def show_create_table(
627636
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
628637
kms_key : str, optional
629638
For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
639+
s3_additional_kwargs : Optional[Dict[str, Any]]
640+
Forward to botocore requests. Valid parameters: "RequestPayer".
641+
e.g. s3_additional_kwargs={'RequestPayer': 'requester'}
630642
boto3_session : boto3.Session(), optional
631643
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
632644
@@ -656,9 +668,7 @@ def show_create_table(
656668
)
657669
query_metadata: _QueryMetadata = _get_query_metadata(query_execution_id=query_id, boto3_session=session)
658670
raw_result = _fetch_txt_result(
659-
query_metadata=query_metadata,
660-
keep_files=True,
661-
boto3_session=session,
671+
query_metadata=query_metadata, keep_files=True, boto3_session=session, s3_additional_kwargs=s3_additional_kwargs
662672
)
663673
return cast(str, raw_result.createtab_stmt.str.strip().str.cat(sep=" "))
664674

awswrangler/redshift.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,9 @@ def _read_parquet_iterator(
304304
)
305305
yield from dfs
306306
if keep_files is False:
307-
s3.delete_objects(path=path, use_threads=use_threads, boto3_session=boto3_session)
307+
s3.delete_objects(
308+
path=path, use_threads=use_threads, boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs
309+
)
308310

309311

310312
def connect(
@@ -1012,7 +1014,9 @@ def unload(
10121014
s3_additional_kwargs=s3_additional_kwargs,
10131015
)
10141016
if keep_files is False:
1015-
s3.delete_objects(path=path, use_threads=use_threads, boto3_session=session)
1017+
s3.delete_objects(
1018+
path=path, use_threads=use_threads, boto3_session=session, s3_additional_kwargs=s3_additional_kwargs
1019+
)
10161020
return df
10171021
return _read_parquet_iterator(
10181022
path=path,
@@ -1129,7 +1133,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
11291133
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
11301134
s3_additional_kwargs:
11311135
Forward to botocore requests. Valid parameters: "ACL", "Metadata", "ServerSideEncryption", "StorageClass",
1132-
"SSECustomerAlgorithm", "SSECustomerKey", "SSEKMSKeyId", "SSEKMSEncryptionContext", "Tagging".
1136+
"SSECustomerAlgorithm", "SSECustomerKey", "SSEKMSKeyId", "SSEKMSEncryptionContext", "Tagging", "RequestPayer".
11331137
e.g. s3_additional_kwargs={'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'YOUR_KMS_KEY_ARN'}
11341138
11351139
Returns
@@ -1307,7 +1311,7 @@ def copy( # pylint: disable=too-many-arguments
13071311
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
13081312
s3_additional_kwargs:
13091313
Forward to botocore requests. Valid parameters: "ACL", "Metadata", "ServerSideEncryption", "StorageClass",
1310-
"SSECustomerAlgorithm", "SSECustomerKey", "SSEKMSKeyId", "SSEKMSEncryptionContext", "Tagging".
1314+
"SSECustomerAlgorithm", "SSECustomerKey", "SSEKMSKeyId", "SSEKMSEncryptionContext", "Tagging", "RequestPayer".
13111315
e.g. s3_additional_kwargs={'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'YOUR_KMS_KEY_ARN'}
13121316
max_rows_by_file : int
13131317
Max number of rows in each file.
@@ -1338,7 +1342,7 @@ def copy( # pylint: disable=too-many-arguments
13381342
path = path[:-1] if path.endswith("*") else path
13391343
path = path if path.endswith("/") else f"{path}/"
13401344
session: boto3.Session = _utils.ensure_session(session=boto3_session)
1341-
if s3.list_objects(path=path, boto3_session=session):
1345+
if s3.list_objects(path=path, boto3_session=session, s3_additional_kwargs=s3_additional_kwargs):
13421346
raise exceptions.InvalidArgument(
13431347
f"The received S3 path ({path}) is not empty. "
13441348
"Please, provide a different path or use wr.s3.delete_objects() to clean up the current one."
@@ -1380,4 +1384,6 @@ def copy( # pylint: disable=too-many-arguments
13801384
)
13811385
finally:
13821386
if keep_files is False:
1383-
s3.delete_objects(path=path, use_threads=use_threads, boto3_session=session)
1387+
s3.delete_objects(
1388+
path=path, use_threads=use_threads, boto3_session=session, s3_additional_kwargs=s3_additional_kwargs
1389+
)

0 commit comments

Comments
 (0)