Skip to content

Commit 4b79b76

Browse files
committed
s3_output: where the query can use the workgroup setting, only set s3_output if passed by user. For queries that require a location, keep previous impl of s3_output either by user or by workgroup setting
1 parent 0d10c7f commit 4b79b76

File tree

2 files changed

+39
-14
lines changed

2 files changed

+39
-14
lines changed

awswrangler/athena/_read.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_get_default_workgroup_config,
2626
_get_query_metadata,
2727
_get_s3_output,
28+
_get_workgroup_config,
2829
_QueryMetadata,
2930
_start_query_execution,
3031
_WorkGroupConfig,
@@ -432,8 +433,7 @@ def _resolve_query_without_cache_regular(
432433
client_request_token: str | None = None,
433434
) -> pd.DataFrame | Iterator[pd.DataFrame]:
434435
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
435-
s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
436-
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
436+
s3_output = s3_output[:-1] if s3_output and s3_output[-1] == "/" else s3_output
437437
_logger.debug("Executing sql: %s", sql)
438438
query_id: str = _start_query_execution(
439439
sql=sql,
@@ -597,13 +597,13 @@ def _unload(
597597
athena_query_wait_polling_delay: float,
598598
execution_params: list[str] | None,
599599
) -> _QueryMetadata:
600-
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
600+
wg_config: _WorkGroupConfig = _get_workgroup_config(workgroup=workgroup, session=boto3_session)
601601
s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session)
602-
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
603-
# Athena does not enforce a Query Result Location for UNLOAD. Thus, the workgroup output location
604-
# is only used if no path is supplied.
605-
if not path:
606-
path = s3_output
602+
s3_output = s3_output[:-1] if s3_output and s3_output[-1] == "/" else s3_output
603+
if not s3_output:
604+
raise exceptions.InvalidArgumentValue(
605+
"Output S3 location is required for UNLOAD, either as the path argument or as a workgroup configuration"
606+
)
607607

608608
# Set UNLOAD parameters
609609
unload_parameters = f" format='{file_format}'"
@@ -614,7 +614,7 @@ def _unload(
614614
if partitioned_by:
615615
unload_parameters += f" , partitioned_by=ARRAY{partitioned_by}"
616616

617-
sql = f"UNLOAD ({sql}) TO '{path}' WITH ({unload_parameters})"
617+
sql = f"UNLOAD ({sql}) TO '{s3_output}' WITH ({unload_parameters})"
618618
_logger.debug("Executing unload query: %s", sql)
619619
try:
620620
query_id: str = _start_query_execution(

awswrangler/athena/_utils.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,20 @@ def _start_query_execution(
9393
args: dict[str, Any] = {"QueryString": sql}
9494

9595
# s3_output
96-
args["ResultConfiguration"] = {
97-
"OutputLocation": _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
98-
}
96+
if s3_output:
97+
args["ResultConfiguration"] = {"OutputLocation": s3_output}
9998

10099
# encryption
101100
if wg_config.enforced is True:
101+
if "ResultConfiguration" not in args:
102+
args["ResultConfiguration"] = {}
102103
if wg_config.encryption is not None:
103104
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": wg_config.encryption}
104105
if wg_config.kms_key is not None:
105106
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = wg_config.kms_key
106107
elif encryption is not None:
108+
if "ResultConfiguration" not in args:
109+
args["ResultConfiguration"] = {}
107110
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
108111
if kms_key is not None:
109112
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key
@@ -146,6 +149,29 @@ def _get_default_workgroup_config() -> _WorkGroupConfig:
146149
return wg_config
147150

148151

152+
def _get_workgroup_config(session: boto3.Session | None = None, workgroup: str = "primary") -> _WorkGroupConfig:
153+
enforced: bool
154+
wg_s3_output: str | None
155+
wg_encryption: str | None
156+
wg_kms_key: str | None
157+
158+
enforced, wg_s3_output, wg_encryption, wg_kms_key = False, None, None, None
159+
if workgroup is not None:
160+
res = get_work_group(workgroup=workgroup, boto3_session=session)
161+
enforced = res["WorkGroup"]["Configuration"]["EnforceWorkGroupConfiguration"]
162+
config: dict[str, Any] = res["WorkGroup"]["Configuration"].get("ResultConfiguration")
163+
if config is not None:
164+
wg_s3_output = config.get("OutputLocation")
165+
encrypt_config: dict[str, str] | None = config.get("EncryptionConfiguration")
166+
wg_encryption = None if encrypt_config is None else encrypt_config.get("EncryptionOption")
167+
wg_kms_key = None if encrypt_config is None else encrypt_config.get("KmsKey")
168+
wg_config: _WorkGroupConfig = _WorkGroupConfig(
169+
enforced=enforced, s3_output=wg_s3_output, encryption=wg_encryption, kms_key=wg_kms_key
170+
)
171+
_logger.debug("Workgroup config:\n%s", wg_config)
172+
return wg_config
173+
174+
149175
def _fetch_txt_result(
150176
query_metadata: _QueryMetadata,
151177
keep_files: bool,
@@ -767,8 +793,7 @@ def create_ctas_table(
767793
fully_qualified_name = f'"{ctas_database}"."{ctas_table}"'
768794

769795
wg_config: _WorkGroupConfig = _get_default_workgroup_config()
770-
s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
771-
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
796+
s3_output = s3_output[:-1] if s3_output and s3_output[-1] == "/" else s3_output
772797
# If the workgroup enforces an external location, then it overrides the user supplied argument
773798
external_location_str: str = (
774799
f" external_location = '{s3_output}/{ctas_table}',\n" if (not wg_config.enforced) and (s3_output) else ""

0 commit comments

Comments
 (0)