Skip to content

Commit c373474

Browse files
committed
Add an argument to _get_workgroup_config and all functions using _get_workgroup_config to determine if the workgroup config should be retreived from the server
1 parent 49fbd1c commit c373474

File tree

6 files changed

+118
-10
lines changed

6 files changed

+118
-10
lines changed

awswrangler/athena/_executions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def start_query_execution(
4747
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
4848
data_source: str | None = None,
4949
wait: bool = False,
50+
retreive_workgroup_config: bool = True,
5051
) -> str | dict[str, Any]:
5152
"""Start a SQL Query against AWS Athena.
5253
@@ -114,6 +115,11 @@ def start_query_execution(
114115
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
115116
wait
116117
Indicates whether to wait for the query to finish and return a dictionary with the query execution response.
118+
retreive_workgroup_config
119+
Indicates whether to use the workgroup configuration for the query execution.
120+
If True, the workgroup configuration will be retreived and used to determine the s3 output location, encryption, and kms key.
121+
If False, the s3 output location, encryption, and kms key will not be set and will be determined by the AWS Athena service.
122+
Default is True.
117123
118124
Returns
119125
-------
@@ -149,7 +155,9 @@ def start_query_execution(
149155
query_execution_id = cache_info.query_execution_id
150156
_logger.debug("Valid cache found. Retrieving...")
151157
else:
152-
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
158+
wg_config: _WorkGroupConfig = _get_workgroup_config(
159+
session=boto3_session, workgroup=workgroup, retreive_workgroup_config=retreive_workgroup_config
160+
)
153161
query_execution_id = _start_query_execution(
154162
sql=sql,
155163
wg_config=wg_config,

awswrangler/athena/_executions.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def start_query_execution(
2424
athena_query_wait_polling_delay: float = ...,
2525
data_source: str | None = ...,
2626
wait: Literal[False] = ...,
27+
retreive_workgroup_config: bool = ...,
2728
) -> str: ...
2829
@overload
2930
def start_query_execution(
@@ -42,6 +43,7 @@ def start_query_execution(
4243
athena_query_wait_polling_delay: float = ...,
4344
data_source: str | None = ...,
4445
wait: Literal[True],
46+
retreive_workgroup_config: bool = ...,
4547
) -> dict[str, Any]: ...
4648
@overload
4749
def start_query_execution(
@@ -60,6 +62,7 @@ def start_query_execution(
6062
athena_query_wait_polling_delay: float = ...,
6163
data_source: str | None = ...,
6264
wait: bool,
65+
retreive_workgroup_config: bool = ...,
6366
) -> str | dict[str, Any]: ...
6467
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = ...) -> None: ...
6568
def wait_query(

awswrangler/athena/_read.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def _resolve_query_without_cache_ctas(
321321
pyarrow_additional_kwargs: dict[str, Any] | None = None,
322322
execution_params: list[str] | None = None,
323323
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
324+
retreive_workgroup_config: bool = True,
324325
) -> pd.DataFrame | Iterator[pd.DataFrame]:
325326
ctas_query_info: dict[str, str | _QueryMetadata] = create_ctas_table(
326327
sql=sql,
@@ -339,6 +340,7 @@ def _resolve_query_without_cache_ctas(
339340
boto3_session=boto3_session,
340341
params=execution_params,
341342
paramstyle="qmark",
343+
retreive_workgroup_config=retreive_workgroup_config,
342344
)
343345
fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"'
344346
ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"])
@@ -379,6 +381,7 @@ def _resolve_query_without_cache_unload(
379381
pyarrow_additional_kwargs: dict[str, Any] | None = None,
380382
execution_params: list[str] | None = None,
381383
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
384+
retreive_workgroup_config: bool = True,
382385
) -> pd.DataFrame | Iterator[pd.DataFrame]:
383386
query_metadata = _unload(
384387
sql=sql,
@@ -395,6 +398,7 @@ def _resolve_query_without_cache_unload(
395398
data_source=data_source,
396399
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
397400
execution_params=execution_params,
401+
retreive_workgroup_config=retreive_workgroup_config,
398402
)
399403
if file_format == "PARQUET":
400404
return _fetch_parquet_result(
@@ -430,8 +434,11 @@ def _resolve_query_without_cache_regular(
430434
result_reuse_configuration: dict[str, Any] | None = None,
431435
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
432436
client_request_token: str | None = None,
437+
retreive_workgroup_config: bool = True,
433438
) -> pd.DataFrame | Iterator[pd.DataFrame]:
434-
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
439+
wg_config: _WorkGroupConfig = _get_workgroup_config(
440+
session=boto3_session, workgroup=workgroup, retreive_workgroup_config=retreive_workgroup_config
441+
)
435442
s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
436443
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
437444
_logger.debug("Executing sql: %s", sql)
@@ -496,6 +503,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
496503
result_reuse_configuration: dict[str, Any] | None = None,
497504
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
498505
client_request_token: str | None = None,
506+
retreive_workgroup_config: bool = True,
499507
) -> pd.DataFrame | Iterator[pd.DataFrame]:
500508
"""
501509
Execute a query in Athena and returns results as DataFrame, back to `read_sql_query`.
@@ -530,6 +538,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
530538
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
531539
execution_params=execution_params,
532540
dtype_backend=dtype_backend,
541+
retreive_workgroup_config=retreive_workgroup_config,
533542
)
534543
finally:
535544
catalog.delete_table_if_exists(database=ctas_database or database, table=name, boto3_session=boto3_session)
@@ -558,6 +567,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
558567
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
559568
execution_params=execution_params,
560569
dtype_backend=dtype_backend,
570+
retreive_workgroup_config=retreive_workgroup_config,
561571
)
562572
return _resolve_query_without_cache_regular(
563573
sql=sql,
@@ -578,6 +588,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
578588
result_reuse_configuration=result_reuse_configuration,
579589
dtype_backend=dtype_backend,
580590
client_request_token=client_request_token,
591+
retreive_workgroup_config=retreive_workgroup_config,
581592
)
582593

583594

@@ -596,8 +607,11 @@ def _unload(
596607
data_source: str | None,
597608
athena_query_wait_polling_delay: float,
598609
execution_params: list[str] | None,
610+
retreive_workgroup_config: bool = True,
599611
) -> _QueryMetadata:
600-
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
612+
wg_config: _WorkGroupConfig = _get_workgroup_config(
613+
session=boto3_session, workgroup=workgroup, retreive_workgroup_config=retreive_workgroup_config
614+
)
601615
s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session)
602616
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
603617
# Athena does not enforce a Query Result Location for UNLOAD. Thus, the workgroup output location
@@ -793,6 +807,7 @@ def read_sql_query(
793807
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
794808
s3_additional_kwargs: dict[str, Any] | None = None,
795809
pyarrow_additional_kwargs: dict[str, Any] | None = None,
810+
retreive_workgroup_config: bool = True,
796811
) -> pd.DataFrame | Iterator[pd.DataFrame]:
797812
"""Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame.
798813
@@ -1002,6 +1017,11 @@ def read_sql_query(
10021017
Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
10031018
Valid values include "split_blocks", "self_destruct", "ignore_metadata".
10041019
e.g. pyarrow_additional_kwargs={'split_blocks': True}.
1020+
retreive_workgroup_config
1021+
Indicates whether to use the workgroup configuration for the query execution.
1022+
If True, the workgroup configuration will be retreived and used to determine the s3 output location, encryption, and kms key.
1023+
If False, the s3 output location, encryption, and kms key will not be set and will be determined by the AWS Athena service.
1024+
Default is True.
10051025
10061026
Returns
10071027
-------
@@ -1120,6 +1140,7 @@ def read_sql_query(
11201140
result_reuse_configuration=result_reuse_configuration,
11211141
dtype_backend=dtype_backend,
11221142
client_request_token=client_request_token,
1143+
retreive_workgroup_config=retreive_workgroup_config,
11231144
)
11241145

11251146

@@ -1386,6 +1407,7 @@ def unload(
13861407
params: dict[str, Any] | list[str] | None = None,
13871408
paramstyle: Literal["qmark", "named"] = "named",
13881409
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
1410+
retreive_workgroup_config: bool = True,
13891411
) -> _QueryMetadata:
13901412
"""Write query results from a SELECT statement to the specified data format using UNLOAD.
13911413
@@ -1442,6 +1464,11 @@ def unload(
14421464
- ``qmark``
14431465
athena_query_wait_polling_delay
14441466
Interval in seconds for how often the function will check if the Athena query has completed.
1467+
retreive_workgroup_config
1468+
Indicates whether to use the workgroup configuration for the query execution.
1469+
If True, the workgroup configuration will be retreived and used to determine the s3 output location, encryption, and kms key.
1470+
If False, the s3 output location, encryption, and kms key will not be set and will be determined by the AWS Athena service.
1471+
Default is True.
14451472
14461473
Returns
14471474
-------
@@ -1473,4 +1500,5 @@ def unload(
14731500
boto3_session=boto3_session,
14741501
data_source=data_source,
14751502
execution_params=execution_params,
1503+
retreive_workgroup_config=retreive_workgroup_config,
14761504
)

awswrangler/athena/_read.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def read_sql_query(
7878
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
7979
s3_additional_kwargs: dict[str, Any] | None = ...,
8080
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
81+
retreive_workgroup_config: bool = ...,
8182
) -> pd.DataFrame: ...
8283
@overload
8384
def read_sql_query(
@@ -105,6 +106,7 @@ def read_sql_query(
105106
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
106107
s3_additional_kwargs: dict[str, Any] | None = ...,
107108
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
109+
retreive_workgroup_config: bool = ...,
108110
) -> Iterator[pd.DataFrame]: ...
109111
@overload
110112
def read_sql_query(
@@ -132,6 +134,7 @@ def read_sql_query(
132134
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
133135
s3_additional_kwargs: dict[str, Any] | None = ...,
134136
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
137+
retreive_workgroup_config: bool = ...,
135138
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
136139
@overload
137140
def read_sql_query(
@@ -159,6 +162,7 @@ def read_sql_query(
159162
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
160163
s3_additional_kwargs: dict[str, Any] | None = ...,
161164
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
165+
retreive_workgroup_config: bool = ...,
162166
) -> Iterator[pd.DataFrame]: ...
163167
@overload
164168
def read_sql_query(
@@ -186,6 +190,7 @@ def read_sql_query(
186190
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
187191
s3_additional_kwargs: dict[str, Any] | None = ...,
188192
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
193+
retreive_workgroup_config: bool = ...,
189194
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
190195
@overload
191196
def read_sql_table(
@@ -210,6 +215,7 @@ def read_sql_table(
210215
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
211216
s3_additional_kwargs: dict[str, Any] | None = ...,
212217
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
218+
retreive_workgroup_config: bool = ...,
213219
) -> pd.DataFrame: ...
214220
@overload
215221
def read_sql_table(
@@ -234,6 +240,7 @@ def read_sql_table(
234240
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
235241
s3_additional_kwargs: dict[str, Any] | None = ...,
236242
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
243+
retreive_workgroup_config: bool = ...,
237244
) -> Iterator[pd.DataFrame]: ...
238245
@overload
239246
def read_sql_table(
@@ -258,6 +265,7 @@ def read_sql_table(
258265
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
259266
s3_additional_kwargs: dict[str, Any] | None = ...,
260267
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
268+
retreive_workgroup_config: bool = ...,
261269
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
262270
@overload
263271
def read_sql_table(
@@ -282,6 +290,7 @@ def read_sql_table(
282290
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
283291
s3_additional_kwargs: dict[str, Any] | None = ...,
284292
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
293+
retreive_workgroup_config: bool = ...,
285294
) -> Iterator[pd.DataFrame]: ...
286295
@overload
287296
def read_sql_table(
@@ -306,6 +315,7 @@ def read_sql_table(
306315
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
307316
s3_additional_kwargs: dict[str, Any] | None = ...,
308317
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
318+
retreive_workgroup_config: bool = ...,
309319
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
310320
def unload(
311321
sql: str,
@@ -323,4 +333,5 @@ def unload(
323333
params: dict[str, Any] | list[str] | None = ...,
324334
paramstyle: Literal["qmark", "named"] = ...,
325335
athena_query_wait_polling_delay: float = ...,
336+
retreive_workgroup_config: bool = ...,
326337
) -> _QueryMetadata: ...

0 commit comments

Comments
 (0)