Commit 311563e

feat: Apache Spark on Amazon Athena - wr.athena.create_spark_session & wr.athena.run_spark_calculation (#2314)
* feat: Spark on Athena - checkpoint
* feat: Spark on Athena - temp remove result handling, add create_spark_session & fix types
* [skip ci] Docstrings
* [skip ci] Fix output types & add test case spark code
* Remove comments
* Upd api docs
* [skip ci] Add tutorial
* [skip ci] Add IAM role
* [skip ci] Update docstrings
* [skip ci] Add examples
* [skip ci] Reuse inline LF policy
1 parent 9f4d83d commit 311563e

File tree: 7 files changed (+492, -20 lines)

awswrangler/athena/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,7 @@
     start_query_execution,
     wait_query,
 )
+from awswrangler.athena._spark import create_spark_session, run_spark_calculation
 from awswrangler.athena._read import (  # noqa
     get_query_results,
     read_sql_query,
@@ -42,6 +43,8 @@
     "generate_create_query",
     "list_query_executions",
     "repair_table",
+    "create_spark_session",
+    "run_spark_calculation",
     "create_ctas_table",
     "show_create_table",
     "start_query_execution",

awswrangler/athena/_spark.py

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
"""Apache Spark on Amazon Athena Module."""
# pylint: disable=too-many-lines
import logging
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import boto3

from awswrangler import _utils, exceptions

_logger: logging.Logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from mypy_boto3_athena.type_defs import (
        EngineConfigurationTypeDef,
        GetCalculationExecutionResponseTypeDef,
        GetCalculationExecutionStatusResponseTypeDef,
        GetSessionStatusResponseTypeDef,
    )

_SESSION_FINAL_STATES: List[str] = ["IDLE", "TERMINATED", "DEGRADED", "FAILED"]
_CALCULATION_EXECUTION_FINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELED"]
_SESSION_WAIT_POLLING_DELAY: float = 5.0  # SECONDS
_CALCULATION_EXECUTION_WAIT_POLLING_DELAY: float = 5.0  # SECONDS


def _wait_session(
    session_id: str,
    boto3_session: Optional[boto3.Session] = None,
    athena_session_wait_polling_delay: float = _SESSION_WAIT_POLLING_DELAY,
) -> "GetSessionStatusResponseTypeDef":
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    response: "GetSessionStatusResponseTypeDef" = client_athena.get_session_status(SessionId=session_id)
    state: str = response["Status"]["State"]

    while state not in _SESSION_FINAL_STATES:
        time.sleep(athena_session_wait_polling_delay)
        response = client_athena.get_session_status(SessionId=session_id)
        state = response["Status"]["State"]
    _logger.debug("Session state: %s", state)
    _logger.debug("Session state change reason: %s", response["Status"].get("StateChangeReason"))
    if state in ["FAILED", "DEGRADED", "TERMINATED"]:
        raise exceptions.SessionFailed(response["Status"].get("StateChangeReason"))
    return response


def _wait_calculation_execution(
    calculation_execution_id: str,
    boto3_session: Optional[boto3.Session] = None,
    athena_calculation_execution_wait_polling_delay: float = _CALCULATION_EXECUTION_WAIT_POLLING_DELAY,
) -> "GetCalculationExecutionStatusResponseTypeDef":
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    response: "GetCalculationExecutionStatusResponseTypeDef" = client_athena.get_calculation_execution_status(
        CalculationExecutionId=calculation_execution_id
    )
    state: str = response["Status"]["State"]

    while state not in _CALCULATION_EXECUTION_FINAL_STATES:
        time.sleep(athena_calculation_execution_wait_polling_delay)
        response = client_athena.get_calculation_execution_status(CalculationExecutionId=calculation_execution_id)
        state = response["Status"]["State"]
    _logger.debug("Calculation execution state: %s", state)
    _logger.debug("Calculation execution state change reason: %s", response["Status"].get("StateChangeReason"))
    if state in ["CANCELED", "FAILED"]:
        raise exceptions.CalculationFailed(response["Status"].get("StateChangeReason"))
    return response


def _get_calculation_execution_results(
    calculation_execution_id: str,
    boto3_session: Optional[boto3.Session] = None,
) -> Dict[str, Any]:
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    _wait_calculation_execution(
        calculation_execution_id=calculation_execution_id,
        boto3_session=boto3_session,
    )

    response: "GetCalculationExecutionResponseTypeDef" = client_athena.get_calculation_execution(
        CalculationExecutionId=calculation_execution_id,
    )
    return cast(Dict[str, Any], response)


def create_spark_session(
    workgroup: str,
    coordinator_dpu_size: int = 1,
    max_concurrent_dpus: int = 5,
    default_executor_dpu_size: int = 1,
    additional_configs: Optional[Dict[str, Any]] = None,
    idle_timeout: int = 15,
    boto3_session: Optional[boto3.Session] = None,
) -> str:
    """
    Create a session and wait until it is ready to accept calculations.

    Parameters
    ----------
    workgroup : str
        Athena workgroup name. Must be Spark-enabled.
    coordinator_dpu_size : int, optional
        The number of DPUs to use for the coordinator. A coordinator is a special executor that orchestrates
        processing work and manages other executors in a notebook session. The default is 1.
    max_concurrent_dpus : int, optional
        The maximum number of DPUs that can run concurrently. The default is 5.
    default_executor_dpu_size : int, optional
        The default number of DPUs to use for executors. The default is 1.
    additional_configs : Dict[str, Any], optional
        Contains additional engine parameter mappings in the form of key-value pairs.
    idle_timeout : int, optional
        The idle timeout in minutes for the session. The default is 15.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session is used if boto3_session receives None.

    Returns
    -------
    str
        Session id

    Examples
    --------
    >>> import awswrangler as wr
    >>> session_id = wr.athena.create_spark_session(workgroup="...", max_concurrent_dpus=10)

    """
    client_athena = _utils.client(service_name="athena", session=boto3_session)
    engine_configuration: "EngineConfigurationTypeDef" = {
        "CoordinatorDpuSize": coordinator_dpu_size,
        "MaxConcurrentDpus": max_concurrent_dpus,
        "DefaultExecutorDpuSize": default_executor_dpu_size,
    }
    if additional_configs:
        engine_configuration["AdditionalConfigs"] = additional_configs
    response = client_athena.start_session(
        WorkGroup=workgroup,
        EngineConfiguration=engine_configuration,
        SessionIdleTimeoutInMinutes=idle_timeout,
    )
    _logger.info("Session info:\n%s", response)
    session_id: str = response["SessionId"]
    # Wait for the session to reach IDLE state to be able to accept calculations
    _wait_session(
        session_id=session_id,
        boto3_session=boto3_session,
    )
    return session_id


def run_spark_calculation(
    code: str,
    workgroup: str,
    session_id: Optional[str] = None,
    coordinator_dpu_size: int = 1,
    max_concurrent_dpus: int = 5,
    default_executor_dpu_size: int = 1,
    additional_configs: Optional[Dict[str, Any]] = None,
    idle_timeout: int = 15,
    boto3_session: Optional[boto3.Session] = None,
) -> Dict[str, Any]:
    """
    Execute Spark Calculation and wait for completion.

    Parameters
    ----------
    code : str
        A string that contains the code for the calculation.
    workgroup : str
        Athena workgroup name. Must be Spark-enabled.
    session_id : str, optional
        The session id. If not passed, a session will be started.
    coordinator_dpu_size : int, optional
        The number of DPUs to use for the coordinator. A coordinator is a special executor that orchestrates
        processing work and manages other executors in a notebook session. The default is 1.
    max_concurrent_dpus : int, optional
        The maximum number of DPUs that can run concurrently. The default is 5.
    default_executor_dpu_size : int, optional
        The default number of DPUs to use for executors. The default is 1.
    additional_configs : Dict[str, Any], optional
        Contains additional engine parameter mappings in the form of key-value pairs.
    idle_timeout : int, optional
        The idle timeout in minutes for the session. The default is 15.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session is used if boto3_session receives None.

    Returns
    -------
    Dict[str, Any]
        Calculation response

    Examples
    --------
    >>> import awswrangler as wr
    >>> result = wr.athena.run_spark_calculation(
    ...     code="print(spark)",
    ...     workgroup="...",
    ... )

    """
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    session_id = (
        create_spark_session(
            workgroup=workgroup,
            coordinator_dpu_size=coordinator_dpu_size,
            max_concurrent_dpus=max_concurrent_dpus,
            default_executor_dpu_size=default_executor_dpu_size,
            additional_configs=additional_configs,
            idle_timeout=idle_timeout,
            boto3_session=boto3_session,
        )
        if not session_id
        else session_id
    )

    response = client_athena.start_calculation_execution(
        SessionId=session_id,
        CodeBlock=code,
    )
    _logger.info("Calculation execution info:\n%s", response)

    return _get_calculation_execution_results(
        calculation_execution_id=response["CalculationExecutionId"],
        boto3_session=boto3_session,
    )
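A minimal usage sketch of the two new functions (the workgroup name and configuration values below are placeholders, not taken from this commit): a session is created once and then reused across calculations by passing session_id.

import awswrangler as wr

# Placeholder workgroup: must be an existing Spark-enabled Athena workgroup.
session_id = wr.athena.create_spark_session(
    workgroup="my-spark-workgroup",
    max_concurrent_dpus=10,
    idle_timeout=30,  # minutes
)

# Passing session_id reuses the session; omitting it makes
# run_spark_calculation start a fresh session for the call.
result = wr.athena.run_spark_calculation(
    code="print(spark.version)",
    workgroup="my-spark-workgroup",
    session_id=session_id,
)
print(result["Status"]["State"])  # "COMPLETED" on success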

awswrangler/exceptions.py

Lines changed: 8 additions & 0 deletions
@@ -49,6 +49,14 @@ class QueryCancelled(Exception):
     """QueryCancelled exception."""


+class SessionFailed(Exception):
+    """SessionFailed exception."""
+
+
+class CalculationFailed(Exception):
+    """CalculationFailed exception."""
+
+
 class EmptyDataFrame(Exception):
     """EmptyDataFrame exception."""
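These exceptions are raised by the polling helpers in _spark.py when a session or calculation reaches a terminal failure state. A sketch of how a caller might handle them (the workgroup name is a placeholder):

import awswrangler as wr
from awswrangler import exceptions

try:
    result = wr.athena.run_spark_calculation(
        code="print(spark)",
        workgroup="my-spark-workgroup",  # placeholder
    )
except exceptions.SessionFailed as err:
    # Session ended in FAILED, DEGRADED, or TERMINATED state
    print(f"Spark session could not be used: {err}")
except exceptions.CalculationFailed as err:
    # Calculation ended in FAILED or CANCELED state
    print(f"Spark calculation did not complete: {err}")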

docs/source/api.rst

Lines changed: 2 additions & 0 deletions
@@ -121,6 +121,7 @@ Amazon Athena
    :toctree: stubs

    create_athena_bucket
+   create_spark_session
    create_ctas_table
    generate_create_query
    get_query_columns_types
@@ -133,6 +134,7 @@ Amazon Athena
    read_sql_query
    read_sql_table
    repair_table
+   run_spark_calculation
    start_query_execution
    stop_query_execution
    to_iceberg

test_infra/stacks/base_stack.py

Lines changed: 25 additions & 20 deletions
@@ -87,6 +87,16 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: str) -> None:
             resource_arn=self.bucket.bucket_arn,
             use_service_linked_role=True,
         )
+        inline_lf_policies = {
+            "GetDataAccess": iam.PolicyDocument(
+                statements=[
+                    iam.PolicyStatement(
+                        actions=["lakeformation:GetDataAccess"],
+                        resources=["*"],
+                    ),
+                ]
+            ),
+        }
         glue_data_quality_role = iam.Role(
             self,
             "aws-sdk-pandas-glue-data-quality-role",
@@ -96,16 +106,7 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: str) -> None:
                 iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess"),
                 iam.ManagedPolicy.from_aws_managed_policy_name("AWSGlueConsoleFullAccess"),
             ],
-            inline_policies={
-                "GetDataAccess": iam.PolicyDocument(
-                    statements=[
-                        iam.PolicyStatement(
-                            actions=["lakeformation:GetDataAccess"],
-                            resources=["*"],
-                        ),
-                    ]
-                ),
-            },
+            inline_policies=inline_lf_policies,
         )
         emr_serverless_exec_role = iam.Role(
             self,
@@ -116,16 +117,19 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: str) -> None:
                 iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess"),
                 iam.ManagedPolicy.from_aws_managed_policy_name("AWSGlueConsoleFullAccess"),
             ],
-            inline_policies={
-                "GetDataAccess": iam.PolicyDocument(
-                    statements=[
-                        iam.PolicyStatement(
-                            actions=["lakeformation:GetDataAccess"],
-                            resources=["*"],
-                        ),
-                    ]
-                ),
-            },
+            inline_policies=inline_lf_policies,
+        )
+        athena_spark_exec_role = iam.Role(
+            self,
+            "aws-sdk-pandas-athena-spark-exec-role",
+            role_name="AthenaSparkExecutionRole",
+            assumed_by=iam.ServicePrincipal("athena.amazonaws.com"),
+            managed_policies=[
+                iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess"),
+                iam.ManagedPolicy.from_aws_managed_policy_name("AWSGlueConsoleFullAccess"),
+                iam.ManagedPolicy.from_aws_managed_policy_name("AmazonAthenaFullAccess"),
+            ],
+            inline_policies=inline_lf_policies,
         )
         glue_db = glue.Database(
             self,
@@ -199,6 +203,7 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: str) -> None:
         CfnOutput(self, "GlueDatabaseName", value=glue_db.database_name)
         CfnOutput(self, "GlueDataQualityRole", value=glue_data_quality_role.role_arn)
         CfnOutput(self, "EMRServerlessExecutionRoleArn", value=emr_serverless_exec_role.role_arn)
+        CfnOutput(self, "AthenaSparkExecutionRoleArn", value=athena_spark_exec_role.role_arn)
         CfnOutput(self, "LogGroupName", value=log_group.log_group_name)
         CfnOutput(self, "LogStream", value=log_stream.log_stream_name)

tests/unit/test_athena_spark.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
import pytest

import awswrangler as wr
from tests._utils import create_workgroup


@pytest.fixture(scope="session")
def athena_spark_execution_role_arn(cloudformation_outputs):
    return cloudformation_outputs["AthenaSparkExecutionRoleArn"]


@pytest.fixture(scope="session")
def workgroup_spark(bucket, kms_key, athena_spark_execution_role_arn):
    return create_workgroup(
        wkg_name="aws_sdk_pandas_spark",
        config={
            "EngineVersion": {
                "SelectedEngineVersion": "PySpark engine version 3",
            },
            "ExecutionRole": athena_spark_execution_role_arn,
            "ResultConfiguration": {"OutputLocation": f"s3://{bucket}/athena_workgroup_spark/"},
        },
    )


@pytest.mark.parametrize(
    "code",
    [
        "print(spark)",
        """
input_path = "s3://athena-examples-us-east-1/notebooks/yellow_tripdata_2016-01.parquet"
output_path = "$PATH"

taxi_df = spark.read.format("parquet").load(input_path)

taxi_passenger_counts = taxi_df.groupBy("VendorID", "passenger_count").count()
taxi_passenger_counts.coalesce(1).write.mode('overwrite').csv(output_path)
""",
    ],
)
def test_athena_spark_calculation(code, path, workgroup_spark):
    code = code.replace("$PATH", path)

    result = wr.athena.run_spark_calculation(
        code=code,
        workgroup=workgroup_spark,
    )

    assert result["Status"]["State"] == "COMPLETED"
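The test relies on the create_workgroup helper and the CDK stack above; outside that infrastructure, a Spark-enabled workgroup could be created directly with boto3. A sketch under assumed placeholder values (account id, role name, and bucket are illustrative):

import boto3

athena = boto3.client("athena")

# Placeholders: the execution role ARN and output bucket are illustrative only.
athena.create_work_group(
    Name="aws_sdk_pandas_spark",
    Configuration={
        "EngineVersion": {"SelectedEngineVersion": "PySpark engine version 3"},
        "ExecutionRole": "arn:aws:iam::123456789012:role/AthenaSparkExecutionRole",
        "ResultConfiguration": {"OutputLocation": "s3://my-bucket/athena_workgroup_spark/"},
    },
)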
