@@ -321,6 +321,7 @@ def _resolve_query_without_cache_ctas(
321321 pyarrow_additional_kwargs : dict [str , Any ] | None = None ,
322322 execution_params : list [str ] | None = None ,
323323 dtype_backend : Literal ["numpy_nullable" , "pyarrow" ] = "numpy_nullable" ,
324+ retreive_workgroup_config : bool = True ,
324325) -> pd .DataFrame | Iterator [pd .DataFrame ]:
325326 ctas_query_info : dict [str , str | _QueryMetadata ] = create_ctas_table (
326327 sql = sql ,
@@ -339,6 +340,7 @@ def _resolve_query_without_cache_ctas(
339340 boto3_session = boto3_session ,
340341 params = execution_params ,
341342 paramstyle = "qmark" ,
343+ retreive_workgroup_config = retreive_workgroup_config ,
342344 )
343345 fully_qualified_name : str = f'"{ ctas_query_info ["ctas_database" ]} "."{ ctas_query_info ["ctas_table" ]} "'
344346 ctas_query_metadata = cast (_QueryMetadata , ctas_query_info ["ctas_query_metadata" ])
@@ -379,6 +381,7 @@ def _resolve_query_without_cache_unload(
379381 pyarrow_additional_kwargs : dict [str , Any ] | None = None ,
380382 execution_params : list [str ] | None = None ,
381383 dtype_backend : Literal ["numpy_nullable" , "pyarrow" ] = "numpy_nullable" ,
384+ retreive_workgroup_config : bool = True ,
382385) -> pd .DataFrame | Iterator [pd .DataFrame ]:
383386 query_metadata = _unload (
384387 sql = sql ,
@@ -395,6 +398,7 @@ def _resolve_query_without_cache_unload(
395398 data_source = data_source ,
396399 athena_query_wait_polling_delay = athena_query_wait_polling_delay ,
397400 execution_params = execution_params ,
401+ retreive_workgroup_config = retreive_workgroup_config ,
398402 )
399403 if file_format == "PARQUET" :
400404 return _fetch_parquet_result (
@@ -430,8 +434,11 @@ def _resolve_query_without_cache_regular(
430434 result_reuse_configuration : dict [str , Any ] | None = None ,
431435 dtype_backend : Literal ["numpy_nullable" , "pyarrow" ] = "numpy_nullable" ,
432436 client_request_token : str | None = None ,
437+ retreive_workgroup_config : bool = True ,
433438) -> pd .DataFrame | Iterator [pd .DataFrame ]:
434- wg_config : _WorkGroupConfig = _get_workgroup_config (session = boto3_session , workgroup = workgroup )
439+ wg_config : _WorkGroupConfig = _get_workgroup_config (
440+ session = boto3_session , workgroup = workgroup , retreive_workgroup_config = retreive_workgroup_config
441+ )
435442 s3_output = _get_s3_output (s3_output = s3_output , wg_config = wg_config , boto3_session = boto3_session )
436443 s3_output = s3_output [:- 1 ] if s3_output [- 1 ] == "/" else s3_output
437444 _logger .debug ("Executing sql: %s" , sql )
@@ -496,6 +503,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
496503 result_reuse_configuration : dict [str , Any ] | None = None ,
497504 dtype_backend : Literal ["numpy_nullable" , "pyarrow" ] = "numpy_nullable" ,
498505 client_request_token : str | None = None ,
506+ retreive_workgroup_config : bool = True ,
499507) -> pd .DataFrame | Iterator [pd .DataFrame ]:
500508 """
501509 Execute a query in Athena and returns results as DataFrame, back to `read_sql_query`.
@@ -530,6 +538,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
530538 pyarrow_additional_kwargs = pyarrow_additional_kwargs ,
531539 execution_params = execution_params ,
532540 dtype_backend = dtype_backend ,
541+ retreive_workgroup_config = retreive_workgroup_config ,
533542 )
534543 finally :
535544 catalog .delete_table_if_exists (database = ctas_database or database , table = name , boto3_session = boto3_session )
@@ -558,6 +567,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
558567 pyarrow_additional_kwargs = pyarrow_additional_kwargs ,
559568 execution_params = execution_params ,
560569 dtype_backend = dtype_backend ,
570+ retreive_workgroup_config = retreive_workgroup_config ,
561571 )
562572 return _resolve_query_without_cache_regular (
563573 sql = sql ,
@@ -578,6 +588,7 @@ def _resolve_query_without_cache( # noqa: PLR0913
578588 result_reuse_configuration = result_reuse_configuration ,
579589 dtype_backend = dtype_backend ,
580590 client_request_token = client_request_token ,
591+ retreive_workgroup_config = retreive_workgroup_config ,
581592 )
582593
583594
@@ -596,8 +607,11 @@ def _unload(
596607 data_source : str | None ,
597608 athena_query_wait_polling_delay : float ,
598609 execution_params : list [str ] | None ,
610+ retreive_workgroup_config : bool = True ,
599611) -> _QueryMetadata :
600- wg_config : _WorkGroupConfig = _get_workgroup_config (session = boto3_session , workgroup = workgroup )
612+ wg_config : _WorkGroupConfig = _get_workgroup_config (
613+ session = boto3_session , workgroup = workgroup , retreive_workgroup_config = retreive_workgroup_config
614+ )
601615 s3_output : str = _get_s3_output (s3_output = path , wg_config = wg_config , boto3_session = boto3_session )
602616 s3_output = s3_output [:- 1 ] if s3_output [- 1 ] == "/" else s3_output
603617 # Athena does not enforce a Query Result Location for UNLOAD. Thus, the workgroup output location
@@ -793,6 +807,7 @@ def read_sql_query(
793807 dtype_backend : Literal ["numpy_nullable" , "pyarrow" ] = "numpy_nullable" ,
794808 s3_additional_kwargs : dict [str , Any ] | None = None ,
795809 pyarrow_additional_kwargs : dict [str , Any ] | None = None ,
810+ retreive_workgroup_config : bool = True ,
796811) -> pd .DataFrame | Iterator [pd .DataFrame ]:
797812 """Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame.
798813
@@ -1002,6 +1017,11 @@ def read_sql_query(
10021017 Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
10031018 Valid values include "split_blocks", "self_destruct", "ignore_metadata".
10041019 e.g. pyarrow_additional_kwargs={'split_blocks': True}.
1020+ retreive_workgroup_config
1021+ Indicates whether to use the workgroup configuration for the query execution.
1022+ If True, the workgroup configuration will be retrieved and used to determine the S3 output location, encryption, and KMS key.
1023+ If False, the S3 output location, encryption, and KMS key will not be set and will be determined by the AWS Athena service.
1024+ Default is True.
10051025
10061026 Returns
10071027 -------
@@ -1120,6 +1140,7 @@ def read_sql_query(
11201140 result_reuse_configuration = result_reuse_configuration ,
11211141 dtype_backend = dtype_backend ,
11221142 client_request_token = client_request_token ,
1143+ retreive_workgroup_config = retreive_workgroup_config ,
11231144 )
11241145
11251146
@@ -1386,6 +1407,7 @@ def unload(
13861407 params : dict [str , Any ] | list [str ] | None = None ,
13871408 paramstyle : Literal ["qmark" , "named" ] = "named" ,
13881409 athena_query_wait_polling_delay : float = _QUERY_WAIT_POLLING_DELAY ,
1410+ retreive_workgroup_config : bool = True ,
13891411) -> _QueryMetadata :
13901412 """Write query results from a SELECT statement to the specified data format using UNLOAD.
13911413
@@ -1442,6 +1464,11 @@ def unload(
14421464 - ``qmark``
14431465 athena_query_wait_polling_delay
14441466 Interval in seconds for how often the function will check if the Athena query has completed.
1467+ retreive_workgroup_config
1468+ Indicates whether to use the workgroup configuration for the query execution.
1469+ If True, the workgroup configuration will be retrieved and used to determine the S3 output location, encryption, and KMS key.
1470+ If False, the S3 output location, encryption, and KMS key will not be set and will be determined by the AWS Athena service.
1471+ Default is True.
14451472
14461473 Returns
14471474 -------
@@ -1473,4 +1500,5 @@ def unload(
14731500 boto3_session = boto3_session ,
14741501 data_source = data_source ,
14751502 execution_params = execution_params ,
1503+ retreive_workgroup_config = retreive_workgroup_config ,
14761504 )
0 commit comments