Skip to content

Commit 1fc775f

Browse files
committed
Add AWS STS module.
1 parent b999930 commit 1fc775f

File tree

10 files changed

+31
-37
lines changed

10 files changed

+31
-37
lines changed

awswrangler/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77

88
import logging as _logging
99

10-
from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, quicksight, s3 # noqa
10+
from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, quicksight, s3, sts # noqa
1111
from awswrangler.__metadata__ import __description__, __license__, __title__, __version__ # noqa
12-
from awswrangler._utils import get_account_id # noqa
1312

1413
_logging.getLogger("awswrangler").addHandler(_logging.NullHandler())

awswrangler/_utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,6 @@ def get_directory(path: str) -> str:
207207
return path.rsplit(sep="/", maxsplit=1)[0] + "/"
208208

209209

210-
def get_account_id(boto3_session: Optional[boto3.Session] = None) -> str:
211-
"""Get Account ID."""
212-
session: boto3.Session = ensure_session(session=boto3_session)
213-
return client(service_name="sts", session=session).get_caller_identity().get("Account")
214-
215-
216210
def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session] = None) -> str: # pragma: no cover
217211
"""Extract region from Subnet ID."""
218212
session: boto3.Session = ensure_session(session=boto3_session)

awswrangler/athena.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pandas as pd # type: ignore
1414
import pyarrow as pa # type: ignore
1515

16-
from awswrangler import _data_types, _utils, catalog, exceptions, s3
16+
from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts
1717

1818
_logger: logging.Logger = logging.getLogger(__name__)
1919

@@ -72,7 +72,7 @@ def create_athena_bucket(boto3_session: Optional[boto3.Session] = None) -> str:
7272
7373
"""
7474
session: boto3.Session = _utils.ensure_session(session=boto3_session)
75-
account_id: str = _utils.get_account_id(boto3_session=session)
75+
account_id: str = sts.get_account_id(boto3_session=session)
7676
region_name: str = str(session.region_name).lower()
7777
s3_output = f"s3://aws-athena-query-results-{account_id}-{region_name}/"
7878
s3_resource = session.resource("s3")
@@ -501,7 +501,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals,too-man
501501

502502
if cache_info["has_valid_cache"] is True:
503503
_logger.debug("Valid cache found. Retrieving...")
504-
cache_result: Union[pd.DataFrame, Iterator[pd.DataFrame]] = None
504+
cache_result: Union[pd.DataFrame, Iterator[pd.DataFrame], None] = None
505505
try:
506506
cache_result = _resolve_query_with_cache(
507507
cache_info=cache_info,
@@ -1020,7 +1020,7 @@ def _check_for_cached_results(
10201020
if (current_timestamp - query_info["Status"]["CompletionDateTime"]).total_seconds() > max_cache_seconds:
10211021
break
10221022

1023-
comparison_query: Optional[str] = None
1023+
comparison_query: Optional[str]
10241024
if query_info["StatementType"] == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
10251025
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(query_info["Query"])
10261026
if parsed_query is not None:

awswrangler/emr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import boto3 # type: ignore
99

10-
from awswrangler import _utils, exceptions
10+
from awswrangler import _utils, exceptions, sts
1111

1212
_logger: logging.Logger = logging.getLogger(__name__)
1313

@@ -48,7 +48,7 @@ def _get_default_logging_path(
4848
"""
4949
if account_id is None:
5050
boto3_session = _utils.ensure_session(session=boto3_session)
51-
_account_id: str = _utils.get_account_id(boto3_session=boto3_session)
51+
_account_id: str = sts.get_account_id(boto3_session=boto3_session)
5252
else:
5353
_account_id = account_id
5454
if (region is None) and (subnet_id is not None):
@@ -61,7 +61,7 @@ def _get_default_logging_path(
6161

6262

6363
def _build_cluster_args(**pars): # pylint: disable=too-many-branches,too-many-statements
64-
account_id: str = _utils.get_account_id(boto3_session=pars["boto3_session"])
64+
account_id: str = sts.get_account_id(boto3_session=pars["boto3_session"])
6565
region: str = _utils.get_region_from_session(boto3_session=pars["boto3_session"])
6666

6767
# S3 Logging path
@@ -846,6 +846,7 @@ def build_step(
846846
Examples
847847
--------
848848
>>> import awswrangler as wr
849+
>>> steps = []
849850
>>> for cmd in ['echo "Hello"', "ls -la"]:
850851
... steps.append(wr.emr.build_step(name=cmd, command=cmd))
851852
>>> wr.emr.submit_steps(cluster_id="cluster-id", steps=steps)

awswrangler/quicksight/_cancel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import boto3 # type: ignore
77

8-
from awswrangler import _utils, exceptions
8+
from awswrangler import _utils, exceptions, sts
99
from awswrangler.quicksight._get_list import get_dataset_id
1010

1111
_logger: logging.Logger = logging.getLogger(__name__)
@@ -51,7 +51,7 @@ def cancel_ingestion(
5151
raise exceptions.InvalidArgument("You must pass a not None name or dataset_id argument.")
5252
session: boto3.Session = _utils.ensure_session(session=boto3_session)
5353
if account_id is None:
54-
account_id = _utils.get_account_id(boto3_session=session)
54+
account_id = sts.get_account_id(boto3_session=session)
5555
if (dataset_id is None) and (dataset_name is not None):
5656
dataset_id = get_dataset_id(name=dataset_name, account_id=account_id, boto3_session=session)
5757
client: boto3.client = _utils.client(service_name="quicksight", session=session)

awswrangler/quicksight/_create.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import boto3 # type: ignore
88

9-
from awswrangler import _utils, exceptions
9+
from awswrangler import _utils, exceptions, sts
1010
from awswrangler.quicksight._get_list import get_data_source_arn, get_dataset_id
1111
from awswrangler.quicksight._utils import extract_athena_query_columns, extract_athena_table_columns
1212

@@ -157,7 +157,7 @@ def create_athena_data_source(
157157
session: boto3.Session = _utils.ensure_session(session=boto3_session)
158158
client: boto3.client = _utils.client(service_name="quicksight", session=session)
159159
if account_id is None:
160-
account_id = _utils.get_account_id(boto3_session=session)
160+
account_id = sts.get_account_id(boto3_session=session)
161161
args: Dict[str, Any] = {
162162
"AwsAccountId": account_id,
163163
"DataSourceId": name,
@@ -282,7 +282,7 @@ def create_athena_dataset(
282282
session: boto3.Session = _utils.ensure_session(session=boto3_session)
283283
client: boto3.client = _utils.client(service_name="quicksight", session=session)
284284
if account_id is None:
285-
account_id = _utils.get_account_id(boto3_session=session)
285+
account_id = sts.get_account_id(boto3_session=session)
286286
if (data_source_arn is None) and (data_source_name is not None):
287287
data_source_arn = get_data_source_arn(name=data_source_name, account_id=account_id, boto3_session=session)
288288
if sql is not None:
@@ -379,7 +379,7 @@ def create_ingestion(
379379
"""
380380
session: boto3.Session = _utils.ensure_session(session=boto3_session)
381381
if account_id is None:
382-
account_id = _utils.get_account_id(boto3_session=session)
382+
account_id = sts.get_account_id(boto3_session=session)
383383
if (dataset_name is None) and (dataset_id is None):
384384
raise exceptions.InvalidArgument("You must pass a not None dataset_name or dataset_id argument.")
385385
if (dataset_id is None) and (dataset_name is not None):

awswrangler/quicksight/_delete.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import boto3 # type: ignore
77

8-
from awswrangler import _utils, exceptions
8+
from awswrangler import _utils, exceptions, sts
99
from awswrangler.quicksight._get_list import (
1010
get_dashboard_id,
1111
get_data_source_id,
@@ -25,7 +25,7 @@ def _delete(
2525
) -> None:
2626
session: boto3.Session = _utils.ensure_session(session=boto3_session)
2727
if account_id is None:
28-
account_id = _utils.get_account_id(boto3_session=session)
28+
account_id = sts.get_account_id(boto3_session=session)
2929
client: boto3.client = _utils.client(service_name="quicksight", session=session)
3030
func: Callable = getattr(client, func_name)
3131
func(AwsAccountId=account_id, **kwargs)
@@ -253,7 +253,7 @@ def delete_all_dashboards(account_id: Optional[str] = None, boto3_session: Optio
253253
"""
254254
session: boto3.Session = _utils.ensure_session(session=boto3_session)
255255
if account_id is None:
256-
account_id = _utils.get_account_id(boto3_session=session)
256+
account_id = sts.get_account_id(boto3_session=session)
257257
for dashboard in list_dashboards(account_id=account_id, boto3_session=session):
258258
delete_dashboard(dashboard_id=dashboard["DashboardId"], account_id=account_id, boto3_session=session)
259259

@@ -280,7 +280,7 @@ def delete_all_datasets(account_id: Optional[str] = None, boto3_session: Optiona
280280
"""
281281
session: boto3.Session = _utils.ensure_session(session=boto3_session)
282282
if account_id is None:
283-
account_id = _utils.get_account_id(boto3_session=session)
283+
account_id = sts.get_account_id(boto3_session=session)
284284
for dataset in list_datasets(account_id=account_id, boto3_session=session):
285285
delete_dataset(dataset_id=dataset["DataSetId"], account_id=account_id, boto3_session=session)
286286

@@ -307,7 +307,7 @@ def delete_all_data_sources(account_id: Optional[str] = None, boto3_session: Opt
307307
"""
308308
session: boto3.Session = _utils.ensure_session(session=boto3_session)
309309
if account_id is None:
310-
account_id = _utils.get_account_id(boto3_session=session)
310+
account_id = sts.get_account_id(boto3_session=session)
311311
for data_source in list_data_sources(account_id=account_id, boto3_session=session):
312312
delete_data_source(data_source_id=data_source["DataSourceId"], account_id=account_id, boto3_session=session)
313313

@@ -334,6 +334,6 @@ def delete_all_templates(account_id: Optional[str] = None, boto3_session: Option
334334
"""
335335
session: boto3.Session = _utils.ensure_session(session=boto3_session)
336336
if account_id is None:
337-
account_id = _utils.get_account_id(boto3_session=session)
337+
account_id = sts.get_account_id(boto3_session=session)
338338
for template in list_templates(account_id=account_id, boto3_session=session):
339339
delete_template(template_id=template["TemplateId"], account_id=account_id, boto3_session=session)

awswrangler/quicksight/_describe.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import boto3 # type: ignore
77

8-
from awswrangler import _utils, exceptions
8+
from awswrangler import _utils, exceptions, sts
99
from awswrangler.quicksight._get_list import get_dashboard_id, get_data_source_id, get_dataset_id
1010

1111
_logger: logging.Logger = logging.getLogger(__name__)
@@ -48,7 +48,7 @@ def describe_dashboard(
4848
raise exceptions.InvalidArgument("You must pass a not None name or dashboard_id argument.")
4949
session: boto3.Session = _utils.ensure_session(session=boto3_session)
5050
if account_id is None:
51-
account_id = _utils.get_account_id(boto3_session=session)
51+
account_id = sts.get_account_id(boto3_session=session)
5252
if (dashboard_id is None) and (name is not None):
5353
dashboard_id = get_dashboard_id(name=name, account_id=account_id, boto3_session=session)
5454
client: boto3.client = _utils.client(service_name="quicksight", session=session)
@@ -92,7 +92,7 @@ def describe_data_source(
9292
raise exceptions.InvalidArgument("You must pass a not None name or data_source_id argument.")
9393
session: boto3.Session = _utils.ensure_session(session=boto3_session)
9494
if account_id is None:
95-
account_id = _utils.get_account_id(boto3_session=session)
95+
account_id = sts.get_account_id(boto3_session=session)
9696
if (data_source_id is None) and (name is not None):
9797
data_source_id = get_data_source_id(name=name, account_id=account_id, boto3_session=session)
9898
client: boto3.client = _utils.client(service_name="quicksight", session=session)
@@ -136,7 +136,7 @@ def describe_data_source_permissions(
136136
raise exceptions.InvalidArgument("You must pass a not None name or data_source_id argument.")
137137
session: boto3.Session = _utils.ensure_session(session=boto3_session)
138138
if account_id is None:
139-
account_id = _utils.get_account_id(boto3_session=session)
139+
account_id = sts.get_account_id(boto3_session=session)
140140
if (data_source_id is None) and (name is not None):
141141
data_source_id = get_data_source_id(name=name, account_id=account_id, boto3_session=session)
142142
client: boto3.client = _utils.client(service_name="quicksight", session=session)
@@ -180,7 +180,7 @@ def describe_dataset(
180180
raise exceptions.InvalidArgument("You must pass a not None name or dataset_id argument.")
181181
session: boto3.Session = _utils.ensure_session(session=boto3_session)
182182
if account_id is None:
183-
account_id = _utils.get_account_id(boto3_session=session)
183+
account_id = sts.get_account_id(boto3_session=session)
184184
if (dataset_id is None) and (name is not None):
185185
dataset_id = get_dataset_id(name=name, account_id=account_id, boto3_session=session)
186186
client: boto3.client = _utils.client(service_name="quicksight", session=session)
@@ -227,7 +227,7 @@ def describe_ingestion(
227227
raise exceptions.InvalidArgument("You must pass a not None name or dataset_id argument.")
228228
session: boto3.Session = _utils.ensure_session(session=boto3_session)
229229
if account_id is None:
230-
account_id = _utils.get_account_id(boto3_session=session)
230+
account_id = sts.get_account_id(boto3_session=session)
231231
if (dataset_id is None) and (dataset_name is not None):
232232
dataset_id = get_dataset_id(name=dataset_name, account_id=account_id, boto3_session=session)
233233
client: boto3.client = _utils.client(service_name="quicksight", session=session)

awswrangler/quicksight/_get_list.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import boto3 # type: ignore
1111

12-
from awswrangler import _utils, exceptions
12+
from awswrangler import _utils, exceptions, sts
1313

1414
_logger: logging.Logger = logging.getLogger(__name__)
1515

@@ -23,7 +23,7 @@ def _list(
2323
) -> List[Dict[str, Any]]:
2424
session: boto3.Session = _utils.ensure_session(session=boto3_session)
2525
if account_id is None:
26-
account_id = _utils.get_account_id(boto3_session=session)
26+
account_id = sts.get_account_id(boto3_session=session)
2727
client: boto3.client = _utils.client(service_name="quicksight", session=session)
2828
func: Callable = getattr(client, func_name)
2929
response = func(AwsAccountId=account_id, **kwargs)
@@ -408,7 +408,7 @@ def list_ingestions(
408408
raise exceptions.InvalidArgument("You must pass a not None name or dataset_id argument.")
409409
session: boto3.Session = _utils.ensure_session(session=boto3_session)
410410
if account_id is None:
411-
account_id = _utils.get_account_id(boto3_session=session)
411+
account_id = sts.get_account_id(boto3_session=session)
412412
if (dataset_id is None) and (dataset_name is not None):
413413
dataset_id = get_dataset_id(name=dataset_name, account_id=account_id, boto3_session=session)
414414
return _list(

tests/test_moto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import boto3
44
import botocore
5-
import mock
5+
from unittest import mock
66
import moto
77
import pandas as pd
88
import pytest

0 commit comments

Comments
 (0)