Skip to content

Commit f4fe5a3

Browse files
committed
Add endpoints urls in the global config. #418
1 parent b11bca7 commit f4fe5a3

File tree

5 files changed

+182
-185
lines changed

5 files changed

+182
-185
lines changed

awswrangler/_config.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pandas as pd
99

10-
from awswrangler import _utils, exceptions
10+
from awswrangler import exceptions
1111

1212
_logger: logging.Logger = logging.getLogger(__name__)
1313

@@ -31,15 +31,30 @@ class _ConfigArg(NamedTuple):
3131
"max_cache_seconds": _ConfigArg(dtype=int, nullable=False),
3232
"s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
3333
"workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
34+
# Endpoints URLs
35+
"s3_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
36+
"athena_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
37+
"sts_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
38+
"glue_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
39+
"redshift_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
40+
"kms_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
41+
"emr_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
3442
}
3543

3644

37-
class _Config:
45+
class _Config: # pylint: disable=too-many-instance-attributes
3846
"""Wrangler's Configuration class."""
3947

4048
def __init__(self) -> None:
4149
self._loaded_values: Dict[str, _ConfigValueType] = {}
4250
name: str
51+
self.s3_endpoint_url = None
52+
self.athena_endpoint_url = None
53+
self.sts_endpoint_url = None
54+
self.glue_endpoint_url = None
55+
self.redshift_endpoint_url = None
56+
self.kms_endpoint_url = None
57+
self.emr_endpoint_url = None
4358
for name in _CONFIG_ARGS:
4459
self._load_config(name=name)
4560

@@ -125,7 +140,10 @@ def __getitem__(self, item: str) -> _ConfigValueType:
125140

126141
def _reset_item(self, item: str) -> None:
127142
if item in self._loaded_values:
128-
del self._loaded_values[item]
143+
if item.endswith("_endpoint_url"):
144+
self._loaded_values[item] = None
145+
else:
146+
del self._loaded_values[item]
129147
self._load_config(name=item)
130148

131149
def _repr_html_(self) -> Any:
@@ -224,6 +242,75 @@ def workgroup(self) -> Optional[str]:
224242
def workgroup(self, value: Optional[str]) -> None:
225243
self._set_config_value(key="workgroup", value=value)
226244

245+
@property
246+
def s3_endpoint_url(self) -> Optional[str]:
247+
"""Property s3_endpoint_url."""
248+
return cast(Optional[str], self["s3_endpoint_url"])
249+
250+
@s3_endpoint_url.setter
251+
def s3_endpoint_url(self, value: Optional[str]) -> None:
252+
self._set_config_value(key="s3_endpoint_url", value=value)
253+
254+
@property
255+
def athena_endpoint_url(self) -> Optional[str]:
256+
"""Property athena_endpoint_url."""
257+
return cast(Optional[str], self["athena_endpoint_url"])
258+
259+
@athena_endpoint_url.setter
260+
def athena_endpoint_url(self, value: Optional[str]) -> None:
261+
self._set_config_value(key="athena_endpoint_url", value=value)
262+
263+
@property
264+
def sts_endpoint_url(self) -> Optional[str]:
265+
"""Property sts_endpoint_url."""
266+
return cast(Optional[str], self["sts_endpoint_url"])
267+
268+
@sts_endpoint_url.setter
269+
def sts_endpoint_url(self, value: Optional[str]) -> None:
270+
self._set_config_value(key="sts_endpoint_url", value=value)
271+
272+
@property
273+
def glue_endpoint_url(self) -> Optional[str]:
274+
"""Property glue_endpoint_url."""
275+
return cast(Optional[str], self["glue_endpoint_url"])
276+
277+
@glue_endpoint_url.setter
278+
def glue_endpoint_url(self, value: Optional[str]) -> None:
279+
self._set_config_value(key="glue_endpoint_url", value=value)
280+
281+
@property
282+
def redshift_endpoint_url(self) -> Optional[str]:
283+
"""Property redshift_endpoint_url."""
284+
return cast(Optional[str], self["redshift_endpoint_url"])
285+
286+
@redshift_endpoint_url.setter
287+
def redshift_endpoint_url(self, value: Optional[str]) -> None:
288+
self._set_config_value(key="redshift_endpoint_url", value=value)
289+
290+
@property
291+
def kms_endpoint_url(self) -> Optional[str]:
292+
"""Property kms_endpoint_url."""
293+
return cast(Optional[str], self["kms_endpoint_url"])
294+
295+
@kms_endpoint_url.setter
296+
def kms_endpoint_url(self, value: Optional[str]) -> None:
297+
self._set_config_value(key="kms_endpoint_url", value=value)
298+
299+
@property
300+
def emr_endpoint_url(self) -> Optional[str]:
301+
"""Property emr_endpoint_url."""
302+
return cast(Optional[str], self["emr_endpoint_url"])
303+
304+
@emr_endpoint_url.setter
305+
def emr_endpoint_url(self, value: Optional[str]) -> None:
306+
self._set_config_value(key="emr_endpoint_url", value=value)
307+
308+
309+
def _insert_str(text: str, token: str, insert: str) -> str:
310+
"""Insert string into other."""
311+
index: int = text.find(token)
312+
return text[:index] + insert + text[index:]
313+
227314

228315
def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) -> str:
229316
if doc is None:
@@ -244,7 +331,7 @@ def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) -
244331
" for details.\n"
245332
)
246333
insertion: str = header + args_block + footer + "\n\n"
247-
return _utils.insert_str(text=doc, token="\n Parameters", insert=insertion)
334+
return _insert_str(text=doc, token="\n Parameters", insert=insertion)
248335

249336

250337
def apply_configs(function: Callable[..., Any]) -> Callable[..., Any]:

awswrangler/_utils.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pandas as pd
1717
import psycopg2
1818

19-
from awswrangler import exceptions
19+
from awswrangler import _config, exceptions
2020

2121
_logger: logging.Logger = logging.getLogger(__name__)
2222

@@ -65,14 +65,39 @@ def botocore_config() -> botocore.config.Config:
6565
return botocore.config.Config(retries={"max_attempts": 5}, connect_timeout=10, max_pool_connections=10)
6666

6767

68+
def _get_endpoint_url(service_name: str) -> Optional[str]:
69+
endpoint_url: Optional[str] = None
70+
if service_name == "s3" and _config.config.s3_endpoint_url is not None:
71+
endpoint_url = _config.config.s3_endpoint_url
72+
elif service_name == "athena" and _config.config.athena_endpoint_url is not None:
73+
endpoint_url = _config.config.athena_endpoint_url
74+
elif service_name == "sts" and _config.config.sts_endpoint_url is not None:
75+
endpoint_url = _config.config.sts_endpoint_url
76+
elif service_name == "glue" and _config.config.glue_endpoint_url is not None:
77+
endpoint_url = _config.config.glue_endpoint_url
78+
elif service_name == "redshift" and _config.config.redshift_endpoint_url is not None:
79+
endpoint_url = _config.config.redshift_endpoint_url
80+
elif service_name == "kms" and _config.config.kms_endpoint_url is not None:
81+
endpoint_url = _config.config.kms_endpoint_url
82+
elif service_name == "emr" and _config.config.emr_endpoint_url is not None:
83+
endpoint_url = _config.config.emr_endpoint_url
84+
return endpoint_url
85+
86+
6887
def client(service_name: str, session: Optional[boto3.Session] = None) -> boto3.client:
6988
"""Create a valid boto3.client."""
70-
return ensure_session(session=session).client(service_name=service_name, use_ssl=True, config=botocore_config())
89+
endpoint_url: Optional[str] = _get_endpoint_url(service_name=service_name)
90+
return ensure_session(session=session).client(
91+
service_name=service_name, endpoint_url=endpoint_url, use_ssl=True, config=botocore_config()
92+
)
7193

7294

7395
def resource(service_name: str, session: Optional[boto3.Session] = None) -> boto3.resource:
7496
"""Create a valid boto3.resource."""
75-
return ensure_session(session=session).resource(service_name=service_name, use_ssl=True, config=botocore_config())
97+
endpoint_url: Optional[str] = _get_endpoint_url(service_name=service_name)
98+
return ensure_session(session=session).resource(
99+
service_name=service_name, endpoint_url=endpoint_url, use_ssl=True, config=botocore_config()
100+
)
76101

77102

78103
def parse_path(path: str) -> Tuple[str, str]:
@@ -238,12 +263,6 @@ def ensure_df_is_mutable(df: pd.DataFrame) -> pd.DataFrame:
238263
return df
239264

240265

241-
def insert_str(text: str, token: str, insert: str) -> str:
242-
"""Insert string into other."""
243-
index: int = text.find(token)
244-
return text[:index] + insert + text[index:]
245-
246-
247266
def check_duplicated_columns(df: pd.DataFrame) -> Any:
248267
"""Raise an exception if there are duplicated columns names."""
249268
duplicated: List[str] = df.loc[:, df.columns.duplicated()].columns.to_list()

awswrangler/athena/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def create_athena_bucket(boto3_session: Optional[boto3.Session] = None) -> str:
305305
account_id: str = sts.get_account_id(boto3_session=session)
306306
region_name: str = str(session.region_name).lower()
307307
s3_output = f"s3://aws-athena-query-results-{account_id}-{region_name}/"
308-
s3_resource = session.resource("s3")
308+
s3_resource = _utils.resource(service_name="s3", session=session)
309309
s3_resource.Bucket(s3_output)
310310
return s3_output
311311

tests/test_config.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import logging
22
import os
3+
from unittest.mock import patch
34

5+
import boto3
6+
import botocore
47
import pytest
58

69
import awswrangler as wr
@@ -9,6 +12,26 @@
912
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
1013

1114

15+
def _urls_test(glue_database):
16+
original = botocore.client.ClientCreator.create_client
17+
18+
def wrapper(self, **kwarg):
19+
name = kwarg["service_name"]
20+
url = kwarg["endpoint_url"]
21+
if name == "sts":
22+
assert url == wr.config.sts_endpoint_url
23+
elif name == "athena":
24+
assert url == wr.config.athena_endpoint_url
25+
elif name == "s3":
26+
assert url == wr.config.s3_endpoint_url
27+
elif name == "glue":
28+
assert url == wr.config.glue_endpoint_url
29+
return original(self, **kwarg)
30+
31+
with patch("botocore.client.ClientCreator.create_client", new=wrapper):
32+
wr.athena.read_sql_query(sql="SELECT 1 as col0", database=glue_database)
33+
34+
1235
def test_basics(path, glue_database, glue_table, workgroup0, workgroup1):
1336
args = {"table": glue_table, "path": "", "columns_types": {"col0": "bigint"}}
1437

@@ -71,3 +94,17 @@ def test_basics(path, glue_database, glue_table, workgroup0, workgroup1):
7194
wr.config.reset()
7295
df = wr.athena.read_sql_query(sql="SELECT 1 as col0", database=glue_database)
7396
assert df.query_metadata["WorkGroup"] == workgroup1
97+
98+
# Endpoints URLs
99+
region = boto3.Session().region_name
100+
wr.config.sts_endpoint_url = f"https://sts.{region}.amazonaws.com"
101+
wr.config.s3_endpoint_url = f"https://s3.{region}.amazonaws.com"
102+
wr.config.athena_endpoint_url = f"https://athena.{region}.amazonaws.com"
103+
wr.config.glue_endpoint_url = f"https://glue.{region}.amazonaws.com"
104+
_urls_test(glue_database)
105+
os.environ["WR_STS_ENDPOINT_URL"] = f"https://sts.{region}.amazonaws.com"
106+
os.environ["WR_S3_ENDPOINT_URL"] = f"https://s3.{region}.amazonaws.com"
107+
os.environ["WR_ATHENA_ENDPOINT_URL"] = f"https://athena.{region}.amazonaws.com"
108+
os.environ["WR_GLUE_ENDPOINT_URL"] = f"https://glue.{region}.amazonaws.com"
109+
wr.config.reset()
110+
_urls_test(glue_database)

0 commit comments

Comments
 (0)