Skip to content

Commit 5821c9a

Browse files
authored
Merge pull request #208 from awslabs/emr-6
Add support for Docker and custom classification on EMR
2 parents a837658 + f0f154b commit 5821c9a

File tree

13 files changed

+1072
-79
lines changed

13 files changed

+1072
-79
lines changed

awswrangler/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@
99

1010
from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, s3 # noqa
1111
from awswrangler.__metadata__ import __description__, __license__, __title__, __version__ # noqa
12+
from awswrangler._utils import get_account_id # noqa
1213

1314
logging.getLogger("awswrangler").addHandler(logging.NullHandler())

awswrangler/_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,16 @@ def ensure_postgresql_casts():
166166
def get_directory(path: str) -> str:
167167
"""Extract directory path."""
168168
return path.rsplit(sep="/", maxsplit=1)[0] + "/"
169+
170+
171+
def get_account_id(boto3_session: Optional[boto3.Session] = None) -> str:
172+
"""Get Account ID."""
173+
session: boto3.Session = ensure_session(session=boto3_session)
174+
return client(service_name="sts", session=session).get_caller_identity().get("Account")
175+
176+
177+
def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session] = None) -> str:
178+
"""Extract region from Subnet ID."""
179+
session: boto3.Session = ensure_session(session=boto3_session)
180+
client_ec2: boto3.client = client(service_name="ec2", session=session)
181+
return client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:9]

awswrangler/athena.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def create_athena_bucket(boto3_session: Optional[boto3.Session] = None) -> str:
6868
6969
"""
7070
session: boto3.Session = _utils.ensure_session(session=boto3_session)
71-
account_id: str = _utils.client(service_name="sts", session=session).get_caller_identity().get("Account")
71+
account_id: str = _utils.get_account_id(boto3_session=session)
7272
region_name: str = str(session.region_name).lower()
7373
s3_output = f"s3://aws-athena-query-results-{account_id}-{region_name}/"
7474
s3_resource = session.resource("s3")

0 commit comments

Comments
 (0)