
Commit 9ce624b

Add support for EMR with Docker
1 parent 2f91a50 commit 9ce624b

File tree

13 files changed: +869 additions, -79 deletions


awswrangler/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,5 +9,6 @@

 from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, s3  # noqa
 from awswrangler.__metadata__ import __description__, __license__, __title__, __version__  # noqa
+from awswrangler._utils import get_account_id  # noqa

 logging.getLogger("awswrangler").addHandler(logging.NullHandler())

awswrangler/_utils.py

Lines changed: 13 additions & 0 deletions
@@ -166,3 +166,16 @@ def ensure_postgresql_casts():
 def get_directory(path: str) -> str:
     """Extract directory path."""
     return path.rsplit(sep="/", maxsplit=1)[0] + "/"
+
+
+def get_account_id(boto3_session: Optional[boto3.Session] = None) -> str:
+    """Get Account ID."""
+    session: boto3.Session = ensure_session(session=boto3_session)
+    return client(service_name="sts", session=session).get_caller_identity().get("Account")
+
+
+def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session] = None) -> str:
+    """Extract region from Subnet ID."""
+    session: boto3.Session = ensure_session(session=boto3_session)
+    client_ec2: boto3.client = client(service_name="ec2", session=session)
+    return client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:9]
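
For reference, a minimal usage sketch of the two new helpers (the account id, subnet id, and returned region below are hypothetical placeholders). Note that the [:9] slice assumes a nine-character region name such as us-east-1; a longer region name (e.g. ap-southeast-1) would be truncated.

>>> import awswrangler as wr
>>> wr.get_account_id()  # now re-exported at package level (see __init__.py above)
'123456789012'
>>> wr._utils.get_region_from_subnet(subnet_id='subnet-0abc1234')
'us-east-1'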

awswrangler/athena.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def create_athena_bucket(boto3_session: Optional[boto3.Session] = None) -> str:

     """
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
-    account_id: str = _utils.client(service_name="sts", session=session).get_caller_identity().get("Account")
+    account_id: str = _utils.get_account_id(boto3_session=session)
     region_name: str = str(session.region_name).lower()
     s3_output = f"s3://aws-athena-query-results-{account_id}-{region_name}/"
     s3_resource = session.resource("s3")
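
The behavior is unchanged; only the STS lookup moves into the shared helper. A usage sketch, with an illustrative account id and region:

>>> import awswrangler as wr
>>> wr.athena.create_athena_bucket()
's3://aws-athena-query-results-123456789012-us-east-1/'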

awswrangler/emr.py

Lines changed: 271 additions & 64 deletions
Large diffs are not rendered by default.
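
Although the emr.py diff is collapsed, its new surface can be inferred from testing/test_awswrangler/test_emr.py below. A hedged sketch of the Docker-related additions, using only parameter and function names that appear in that test (the subnet id and image URI are placeholders); since ECR authorization tokens expire after about 12 hours, exposing update_ecr_credentials as a re-runnable call makes sense for long-lived clusters:

import awswrangler as wr

# Sketch only: names are taken from test_emr.py below; the real
# implementation lives in the collapsed emr.py diff.
cluster_id = wr.emr.create_cluster(
    subnet_id="subnet-0abc1234",  # placeholder subnet
    docker=True,  # enable the YARN Docker runtime
    spark_docker=True,  # run Spark jobs inside containers
    spark_docker_image="<account>.dkr.ecr.<region>.amazonaws.com/repo:tag",
    hive_docker=True,
    ecr_credentials_step=True,  # adds a step that logs the cluster in to ECR
)
wr.emr.update_ecr_credentials(cluster_id=cluster_id)  # refresh before the token lapses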

awswrangler/s3.py

Lines changed: 58 additions & 7 deletions
@@ -111,6 +111,40 @@ def does_object_exist(path: str, boto3_session: Optional[boto3.Session] = None)
         raise ex  # pragma: no cover


+def list_directories(path: str, boto3_session: Optional[boto3.Session] = None) -> List[str]:
+    """List Amazon S3 directories (common prefixes) under a prefix.
+
+    Parameters
+    ----------
+    path : str
+        S3 path (e.g. s3://bucket/prefix).
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    List[str]
+        List of directory paths.
+
+    Examples
+    --------
+    Using the default boto3 session
+
+    >>> import awswrangler as wr
+    >>> wr.s3.list_directories('s3://bucket/prefix/')
+    ['s3://bucket/prefix/dir0/', 's3://bucket/prefix/dir1/', 's3://bucket/prefix/dir2/']
+
+    Using a custom boto3 session
+
+    >>> import boto3
+    >>> import awswrangler as wr
+    >>> wr.s3.list_directories('s3://bucket/prefix/', boto3_session=boto3.Session())
+    ['s3://bucket/prefix/dir0/', 's3://bucket/prefix/dir1/', 's3://bucket/prefix/dir2/']
+
+    """
+    return _list_objects(path=path, delimiter="/", boto3_session=boto3_session)
+
+
 def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> List[str]:
     """List Amazon S3 objects from a prefix.
@@ -142,20 +176,37 @@ def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> List[str]:
     ['s3://bucket/prefix0', 's3://bucket/prefix1', 's3://bucket/prefix2']

     """
+    return _list_objects(path=path, delimiter=None, boto3_session=boto3_session)
+
+
+def _list_objects(
+    path: str, delimiter: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
+) -> List[str]:
     client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)
     paginator = client_s3.get_paginator("list_objects_v2")
     bucket: str
     prefix: str
     bucket, prefix = _utils.parse_path(path=path)
-    response_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, PaginationConfig={"PageSize": 1000})
+    args: Dict[str, Any] = {"Bucket": bucket, "Prefix": prefix, "PaginationConfig": {"PageSize": 1000}}
+    if delimiter is not None:
+        args["Delimiter"] = delimiter
+    response_iterator = paginator.paginate(**args)
     paths: List[str] = []
     for page in response_iterator:
-        contents: Optional[List] = page.get("Contents")
-        if contents is not None:
-            for content in contents:
-                if (content is not None) and ("Key" in content):
-                    key: str = content["Key"]
-                    paths.append(f"s3://{bucket}/{key}")
+        if delimiter is None:
+            contents: Optional[List[Optional[Dict[str, str]]]] = page.get("Contents")
+            if contents is not None:
+                for content in contents:
+                    if (content is not None) and ("Key" in content):
+                        key: str = content["Key"]
+                        paths.append(f"s3://{bucket}/{key}")
+        else:
+            prefixes: Optional[List[Optional[Dict[str, str]]]] = page.get("CommonPrefixes")
+            if prefixes is not None:
+                for pfx in prefixes:
+                    if (pfx is not None) and ("Prefix" in pfx):
+                        key = pfx["Prefix"]
+                        paths.append(f"s3://{bucket}/{key}")
     return paths
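
For context, the new delimiter switch maps directly onto S3 ListObjectsV2 semantics: with Delimiter="/", keys sharing the next path segment are rolled up into CommonPrefixes instead of being listed one by one under Contents. A minimal boto3 sketch (bucket and prefix are hypothetical):

import boto3

client = boto3.client("s3")
paginator = client.get_paginator("list_objects_v2")
# Each "directory" appears exactly once in CommonPrefixes, no matter how
# many objects live under it.
for page in paginator.paginate(Bucket="my-bucket", Prefix="prefix/", Delimiter="/"):
    for pfx in page.get("CommonPrefixes", []):
        print(pfx["Prefix"])  # e.g. "prefix/dir0/"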

docs/source/api.rst

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@ Amazon S3
     does_object_exist
     get_bucket_region
     list_objects
+    list_directories
     read_csv
     read_fwf
     read_json
@@ -115,6 +116,7 @@ EMR
     submit_steps
     build_step
     get_step_state
+    update_ecr_credentials

 CloudWatch Logs
 ---------------

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
@@ -17,4 +17,5 @@ twine~=3.1.1
 wheel~=0.34.2
 sphinx~=3.0.1
 sphinx_bootstrap_theme~=0.7.1
-moto~=1.3.14
+moto~=1.3.14
+jupyterlab~=2.1.1

testing/test_awswrangler/test_cloudwatch.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def loggroup(cloudformation_outputs):
 def test_query_cancelled(loggroup):
     client_logs = boto3.client("logs")
     query_id = wr.cloudwatch.start_query(
-        log_group_names=[loggroup], query="fields @timestamp, @message | sort @timestamp desc | limit 5"
+        log_group_names=[loggroup], query="fields @timestamp, @message | sort @timestamp desc"
     )
     client_logs.stop_query(queryId=query_id)
     with pytest.raises(exceptions.QueryCancelled):

testing/test_awswrangler/test_data_lake.py

Lines changed: 3 additions & 0 deletions
@@ -127,6 +127,9 @@ def test_athena_ctas(bucket, database, kms_key):
         partition_cols=["par0", "par1"],
     )["paths"]
     wr.s3.wait_objects_exist(paths=paths)
+    dirs = wr.s3.list_directories(path=f"s3://{bucket}/test_athena_ctas/")
+    for d in dirs:
+        assert d.startswith(f"s3://{bucket}/test_athena_ctas/par0=")
     df = wr.s3.read_parquet_table(table="test_athena_ctas", database=database)
     assert len(df.index) == 3
     ensure_data_types(df=df, has_list=True)

testing/test_awswrangler/test_emr.py

Lines changed: 33 additions & 0 deletions
@@ -146,3 +146,36 @@ def test_cluster_single_node(bucket, cloudformation_outputs):
     wr.emr.submit_steps(cluster_id=cluster_id, steps=steps)
     wr.emr.terminate_cluster(cluster_id=cluster_id)
     wr.s3.delete_objects(f"s3://{bucket}/emr-logs/")
+
+
+def test_default_logging_path(cloudformation_outputs):
+    path = wr.emr._get_default_logging_path(subnet_id=cloudformation_outputs["SubnetId"])
+    assert path.startswith("s3://aws-logs-")
+    assert path.endswith("/elasticmapreduce/")
+    with pytest.raises(wr.exceptions.InvalidArgumentCombination):
+        wr.emr._get_default_logging_path()
+
+
+def test_docker(cloudformation_outputs):
+    cluster_id = wr.emr.create_cluster(
+        subnet_id=cloudformation_outputs["SubnetId"],
+        docker=True,
+        spark_docker=True,
+        spark_docker_image="787535711150.dkr.ecr.us-east-1.amazonaws.com/docker-emr:docker-emr",
+        hive_docker=True,
+        ecr_credentials_step=True,
+        custom_classifications=[
+            {
+                "Classification": "livy-conf",
+                "Properties": {
+                    "livy.spark.master": "yarn",
+                    "livy.spark.deploy-mode": "cluster",
+                    "livy.server.session.timeout": "16h",
+                },
+            }
+        ],
+        steps=[wr.emr.build_step("spark-submit --deploy-mode cluster s3://igor-tavares/emr.py")],
+    )
+    wr.emr.submit_step(cluster_id=cluster_id, command="spark-submit --deploy-mode cluster s3://igor-tavares/emr.py")
+    wr.emr.update_ecr_credentials(cluster_id=cluster_id)
+    wr.emr.terminate_cluster(cluster_id=cluster_id)
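
The two asserts in test_default_logging_path pin down the shape of the default log location. A hedged reconstruction, assuming the collapsed emr.py code simply combines the new get_account_id and get_region_from_subnet helpers (the exact format string is not visible here):

# Not the verbatim source; this just satisfies the startswith/endswith asserts above.
def default_logging_path_sketch(account_id: str, region: str) -> str:
    return f"s3://aws-logs-{account_id}-{region}/elasticmapreduce/"

assert default_logging_path_sketch("123456789012", "us-east-1") == (
    "s3://aws-logs-123456789012-us-east-1/elasticmapreduce/"
)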
