Skip to content

Commit 9029a49

Browse files
feat: EMR Serverless (#2304)
* feat: EMR Serverless - add create_application * Move _get_unique_suffix to test._utils * Add type overload for EMR Serverless boto3 client * Add EMR Serverless create_application test case * Add EMR Serverless - run_job & wait_job * Docstrings & minor fixes * [skip ci] Fix EMR Serverless client stub * [skip ci] Add EMR Serverless IAM Role infra * [skip ci] Add job args typed dicts & update docstrings * [skip ci] Add basic EMR Serverless tutorial * [skip ci] Add emr-serverless boto3 stub * [skip-ci] PR feedback * Re-generate poery.lock * Revert "Re-generate poery.lock" This reverts commit f4dc2ad. * [skip-ci] Typing fixes * [skip-ci] Docs formatting fix --------- Co-authored-by: Lucas Hanson <[email protected]>
1 parent fa3d2ba commit 9029a49

File tree

15 files changed

+3029
-2327
lines changed

15 files changed

+3029
-2327
lines changed

awswrangler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
data_quality,
1717
dynamodb,
1818
emr,
19+
emr_serverless,
1920
exceptions,
2021
lakeformation,
2122
mysql,
@@ -44,6 +45,7 @@
4445
"chime",
4546
"cloudwatch",
4647
"emr",
48+
"emr_serverless",
4749
"data_api",
4850
"data_quality",
4951
"dynamodb",

awswrangler/_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class _ConfigArg(NamedTuple):
4545
"lakeformation_query_wait_polling_delay": _ConfigArg(dtype=float, nullable=False),
4646
"neptune_load_wait_polling_delay": _ConfigArg(dtype=float, nullable=False),
4747
"timestream_batch_load_wait_polling_delay": _ConfigArg(dtype=float, nullable=False),
48+
"emr_serverless_job_wait_polling_delay": _ConfigArg(dtype=float, nullable=False),
4849
"s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
4950
"workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
5051
"chunksize": _ConfigArg(dtype=int, nullable=False, enforced=True),
@@ -377,6 +378,15 @@ def timestream_batch_load_wait_polling_delay(self) -> float:
377378
def timestream_batch_load_wait_polling_delay(self, value: float) -> None:
378379
self._set_config_value(key="timestream_batch_load_wait_polling_delay", value=value)
379380

381+
@property
def emr_serverless_job_wait_polling_delay(self) -> float:
    """Property emr_serverless_job_wait_polling_delay."""
    # Seconds between EMR Serverless job-status polls; value lives in the
    # shared config store and is read via the class's __getitem__.
    return cast(float, self["emr_serverless_job_wait_polling_delay"])
385+
386+
@emr_serverless_job_wait_polling_delay.setter
def emr_serverless_job_wait_polling_delay(self, value: float) -> None:
    # Route through _set_config_value so the config's dtype/nullable rules apply.
    self._set_config_value(key="emr_serverless_job_wait_polling_delay", value=value)
389+
380390
@property
381391
def s3_block_size(self) -> int:
382392
"""Property s3_block_size."""

awswrangler/_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from mypy_boto3_dynamodb import DynamoDBClient, DynamoDBServiceResource
4949
from mypy_boto3_ec2 import EC2Client
5050
from mypy_boto3_emr.client import EMRClient
51+
from mypy_boto3_emr_serverless import EMRServerlessClient
5152
from mypy_boto3_glue import GlueClient
5253
from mypy_boto3_kms.client import KMSClient
5354
from mypy_boto3_lakeformation.client import LakeFormationClient
@@ -70,6 +71,7 @@
7071
"dynamodb",
7172
"ec2",
7273
"emr",
74+
"emr-serverless",
7375
"glue",
7476
"kms",
7577
"lakeformation",
@@ -334,6 +336,16 @@ def client(
334336
...
335337

336338

339+
@overload
def client(
    service_name: 'Literal["emr-serverless"]',
    session: Optional[boto3.Session] = None,
    botocore_config: Optional[Config] = None,
    verify: Optional[Union[str, bool]] = None,
) -> "EMRServerlessClient":
    # Typing-only overload: narrows the return type to the EMR Serverless
    # client stub when service_name is the literal "emr-serverless".
    ...
347+
348+
337349
@overload
338350
def client(
339351
service_name: 'Literal["glue"]',

awswrangler/dynamodb/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def _serialize_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
160160

161161
if "FilterExpression" in kwargs and not isinstance(kwargs["FilterExpression"], str):
162162
builder = ConditionExpressionBuilder()
163-
exp_string, names, values = builder.build_expression(kwargs["FilterExpression"], False) # type: ignore[assignment]
163+
exp_string, names, values = builder.build_expression(kwargs["FilterExpression"], False)
164164
kwargs["FilterExpression"] = exp_string
165165

166166
if "ExpressionAttributeNames" in kwargs:

awswrangler/emr_serverless.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
"""EMR Serverless module."""
2+
3+
import logging
4+
import pprint
5+
import time
6+
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
7+
8+
import boto3
9+
from typing_extensions import NotRequired, Required
10+
11+
from awswrangler import _utils, exceptions
12+
from awswrangler._config import apply_configs
13+
from awswrangler.annotations import Experimental
14+
15+
_logger: logging.Logger = logging.getLogger(__name__)
16+
17+
_EMR_SERVERLESS_JOB_WAIT_POLLING_DELAY: float = 5 # SECONDS
18+
_EMR_SERVERLESS_JOB_FINAL_STATES: List[str] = ["SUCCESS", "FAILED", "CANCELLED"]
19+
20+
21+
class SparkSubmitJobArgs(TypedDict):
    """Typed dictionary defining the Spark submit job arguments."""

    entryPoint: Required[str]
    """The entry point for the Spark submit job run."""
    entryPointArguments: NotRequired[List[str]]
    """The arguments for the Spark submit job run."""
    sparkSubmitParameters: NotRequired[str]
    """The parameters for the Spark submit job run."""
30+
31+
32+
class HiveRunJobArgs(TypedDict):
    """Typed dictionary defining the Hive job run arguments."""

    query: Required[str]
    """The S3 location of the query file for the Hive job run."""
    initQueryFile: NotRequired[str]
    """The S3 location of the query file that is run before the main query file."""
    parameters: NotRequired[str]
    """The parameters for the Hive job run."""
41+
42+
43+
@Experimental
def create_application(
    name: str,
    release_label: str,
    application_type: Literal["Spark", "Hive"] = "Spark",
    initial_capacity: Optional[Dict[str, str]] = None,
    maximum_capacity: Optional[Dict[str, str]] = None,
    tags: Optional[Dict[str, str]] = None,
    autostart: bool = True,
    autostop: bool = True,
    idle_timeout: int = 15,
    network_configuration: Optional[Dict[str, str]] = None,
    architecture: Literal["ARM64", "X86_64"] = "X86_64",
    image_uri: Optional[str] = None,
    worker_type_specifications: Optional[Dict[str, str]] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> str:
    """
    Create an EMR Serverless application.

    https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html

    Parameters
    ----------
    name : str
        Name of EMR Serverless application
    release_label : str
        Release label e.g. `emr-6.10.0`
    application_type : str, optional
        Application type: "Spark" or "Hive". Defaults to "Spark".
    initial_capacity : Dict[str, str], optional
        The capacity to initialize when the application is created.
    maximum_capacity : Dict[str, str], optional
        The maximum capacity to allocate when the application is created.
        This is cumulative across all workers at any given point in time,
        not just when an application is created. No new resources will
        be created once any one of the defined limits is hit.
    tags : Dict[str, str], optional
        Key/Value collection to put tags on the application.
        e.g. {"foo": "boo", "bar": "xoo"}
    autostart : bool, optional
        Enables the application to automatically start on job submission. Defaults to true.
    autostop : bool, optional
        Enables the application to automatically stop after a certain amount of time being idle. Defaults to true.
    idle_timeout : int, optional
        The amount of idle time in minutes after which your application will automatically stop. Defaults to 15 minutes.
    network_configuration : Dict[str, str], optional
        The network configuration for customer VPC connectivity.
    architecture : str, optional
        The CPU architecture of an application: "ARM64" or "X86_64". Defaults to "X86_64".
    image_uri : str, optional
        The URI of an image in the Amazon ECR registry.
    worker_type_specifications : Dict[str, str], optional
        The key-value pairs that specify worker type.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    str
        Application Id.
    """
    emr_serverless = _utils.client(service_name="emr-serverless", session=boto3_session)
    # Mandatory arguments first; optional ones are added below only when set,
    # because the API rejects None for these keys.
    application_args: Dict[str, Any] = {
        "name": name,
        "releaseLabel": release_label,
        "type": application_type,
        "autoStartConfiguration": {
            "enabled": autostart,
        },
        "autoStopConfiguration": {
            "enabled": autostop,
            "idleTimeoutMinutes": idle_timeout,
        },
        "architecture": architecture,
    }
    if initial_capacity:
        application_args["initialCapacity"] = initial_capacity
    if maximum_capacity:
        application_args["maximumCapacity"] = maximum_capacity
    if tags:
        application_args["tags"] = tags
    if network_configuration:
        application_args["networkConfiguration"] = network_configuration
    if worker_type_specifications:
        application_args["workerTypeSpecifications"] = worker_type_specifications
    if image_uri:
        application_args["imageConfiguration"] = {
            "imageUri": image_uri,
        }
    response: Dict[str, str] = emr_serverless.create_application(**application_args)  # type: ignore[assignment]
    _logger.debug("response: \n%s", pprint.pformat(response))
    return response["applicationId"]
136+
137+
138+
@Experimental
@apply_configs
def run_job(
    application_id: str,
    execution_role_arn: str,
    job_driver_args: Union[Dict[str, Any], SparkSubmitJobArgs, HiveRunJobArgs],
    job_type: Literal["Spark", "Hive"] = "Spark",
    wait: bool = True,
    configuration_overrides: Optional[Dict[str, Any]] = None,
    tags: Optional[Dict[str, str]] = None,
    execution_timeout: Optional[int] = None,
    name: Optional[str] = None,
    emr_serverless_job_wait_polling_delay: float = _EMR_SERVERLESS_JOB_WAIT_POLLING_DELAY,
    boto3_session: Optional[boto3.Session] = None,
) -> Union[str, Dict[str, Any]]:
    """
    Run an EMR serverless job.

    https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html

    Parameters
    ----------
    application_id : str
        The id of the application on which to run the job.
    execution_role_arn : str
        The execution role ARN for the job run.
    job_driver_args : Union[Dict[str, Any], SparkSubmitJobArgs, HiveRunJobArgs]
        The job driver arguments for the job run.
    job_type : str, optional
        Type of the job: "Spark" or "Hive". Defaults to "Spark".
    wait : bool, optional
        Whether to wait for the job completion or not. Defaults to true.
    configuration_overrides : Dict[str, str], optional
        The configuration overrides for the job run.
    tags : Dict[str, str], optional
        Key/Value collection to put tags on the application.
        e.g. {"foo": "boo", "bar": "xoo"}
    execution_timeout : int, optional
        The maximum duration for the job run to run. If the job run runs beyond this duration,
        it will be automatically cancelled.
    name : str, optional
        Name of the job.
    emr_serverless_job_wait_polling_delay : float, optional
        Time to wait between polling attempts.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    Union[str, Dict[str, Any]]
        Job Id if wait=False, or job run details.

    Raises
    ------
    exceptions.InvalidArgumentValue
        If ``job_type`` is neither "Spark" nor "Hive".
    """
    emr_serverless = _utils.client(service_name="emr-serverless", session=boto3_session)
    job_args: Dict[str, Any] = {
        "applicationId": application_id,
        "executionRoleArn": execution_role_arn,
    }
    # The jobDriver key name differs per engine ("sparkSubmit" vs "hive").
    if job_type == "Spark":
        job_args["jobDriver"] = {
            "sparkSubmit": job_driver_args,
        }
    elif job_type == "Hive":
        job_args["jobDriver"] = {
            "hive": job_driver_args,
        }
    else:
        raise exceptions.InvalidArgumentValue(f"Unsupported job type `{job_type}`")

    # Optional arguments are only included when provided, as the API rejects None values.
    if configuration_overrides:
        job_args["configurationOverrides"] = configuration_overrides
    if tags:
        job_args["tags"] = tags
    if execution_timeout:
        job_args["executionTimeoutMinutes"] = execution_timeout
    if name:
        job_args["name"] = name
    response = emr_serverless.start_job_run(**job_args)
    _logger.debug("Job run response: %s", response)
    job_run_id: str = response["jobRunId"]
    if wait:
        # Forward the caller's session so polling uses the same credentials/region
        # as the job submission (previously the default session was used here).
        return wait_job(
            application_id=application_id,
            job_run_id=job_run_id,
            emr_serverless_job_wait_polling_delay=emr_serverless_job_wait_polling_delay,
            boto3_session=boto3_session,
        )
    return job_run_id
224+
225+
226+
@Experimental
@apply_configs
def wait_job(
    application_id: str,
    job_run_id: str,
    emr_serverless_job_wait_polling_delay: float = _EMR_SERVERLESS_JOB_WAIT_POLLING_DELAY,
    boto3_session: Optional[boto3.Session] = None,
) -> Dict[str, Any]:
    """
    Wait for the EMR Serverless job to finish.

    https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html

    Parameters
    ----------
    application_id : str
        The id of the application on which the job is running.
    job_run_id : str
        The id of the job.
    emr_serverless_job_wait_polling_delay : float, optional
        Time to wait between polling attempts.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    Dict[str, Any]
        Job run details.

    Raises
    ------
    exceptions.EMRServerlessJobError
        If the job finishes in a final state other than "SUCCESS".
    """
    emr_serverless = _utils.client(service_name="emr-serverless", session=boto3_session)

    def _poll_run() -> Dict[str, Any]:
        # One GetJobRun call for the job being awaited.
        return emr_serverless.get_job_run(  # type: ignore[return-value]
            applicationId=application_id,
            jobRunId=job_run_id,
        )

    run_details = _poll_run()
    current_state: str = run_details["jobRun"]["state"]
    # Keep polling until the job reaches a terminal state.
    while current_state not in _EMR_SERVERLESS_JOB_FINAL_STATES:
        time.sleep(emr_serverless_job_wait_polling_delay)
        run_details = _poll_run()
        current_state = run_details["jobRun"]["state"]
        _logger.debug("Job state: %s", current_state)
    if current_state != "SUCCESS":
        _logger.debug("Job run response: %s", run_details)
        raise exceptions.EMRServerlessJobError(run_details.get("jobRun", {}).get("stateDetails"))
    return run_details

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,7 @@ class TimestreamLoadError(Exception):
143143

144144
class NeptuneLoadError(Exception):
145145
"""NeptuneLoadError."""
146+
147+
148+
class EMRServerlessJobError(Exception):
    """Raised when an EMR Serverless job run ends in a final state other than SUCCESS."""

docs/source/api.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ API Reference
1818
* `DynamoDB`_
1919
* `Amazon Timestream`_
2020
* `Amazon EMR`_
21+
* `Amazon EMR Serverless`_
2122
* `Amazon CloudWatch Logs`_
2223
* `Amazon QuickSight`_
2324
* `AWS STS`_
@@ -360,6 +361,18 @@ Amazon EMR
360361
submit_steps
361362
terminate_cluster
362363

364+
Amazon EMR Serverless
365+
---------------------
366+
367+
.. currentmodule:: awswrangler.emr_serverless
368+
369+
.. autosummary::
370+
:toctree: stubs
371+
372+
create_application
373+
run_job
374+
wait_job
375+
363376
Amazon CloudWatch Logs
364377
----------------------
365378

0 commit comments

Comments
 (0)