Skip to content

Commit 98d4a68

Browse files
committed
Add retry with decorrelated jitter in Athena calls #465
1 parent f7b7b36 commit 98d4a68

File tree

4 files changed

+57
-19
lines changed

4 files changed

+57
-19
lines changed

awswrangler/_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,14 @@ def check_duplicated_columns(df: pd.DataFrame) -> Any:
271271
)
272272

273273

274-
def try_it(f: Callable[..., Any], ex: Any, base: float = 1.0, max_num_tries: int = 3, **kwargs: Any) -> Any:
274+
def try_it(
275+
f: Callable[..., Any],
276+
ex: Any,
277+
ex_code: Optional[str] = None,
278+
base: float = 1.0,
279+
max_num_tries: int = 3,
280+
**kwargs: Any,
281+
) -> Any:
275282
"""Run function with decorrelated Jitter.
276283
277284
Reference: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
@@ -281,8 +288,11 @@ def try_it(f: Callable[..., Any], ex: Any, base: float = 1.0, max_num_tries: int
281288
try:
282289
return f(**kwargs)
283290
except ex as exception:
291+
if ex_code is not None and hasattr(exception, "response"):
292+
if exception.response["Error"]["Code"] != ex_code:
293+
raise
284294
if i == (max_num_tries - 1):
285-
raise exception
295+
raise
286296
delay = random.uniform(base, delay * 3)
287297
_logger.error("Retrying %s | Fail number %s/%s | Exception: %s", f, i + 1, max_num_tries, exception)
288298
time.sleep(delay)

awswrangler/athena/_utils.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Union, cast
99

1010
import boto3
11+
import botocore.exceptions
1112
import pandas as pd
1213

1314
from awswrangler import _data_types, _utils, exceptions, s3, sts
1415
from awswrangler._config import apply_configs
1516

1617
_QUERY_FINAL_STATES: List[str] = ["FAILED", "SUCCEEDED", "CANCELLED"]
17-
_QUERY_WAIT_POLLING_DELAY: float = 0.2 # SECONDS
18+
_QUERY_WAIT_POLLING_DELAY: float = 0.25 # SECONDS
1819

1920
_logger: logging.Logger = logging.getLogger(__name__)
2021

@@ -91,7 +92,13 @@ def _start_query_execution(
9192

9293
client_athena: boto3.client = _utils.client(service_name="athena", session=session)
9394
_logger.debug("args: \n%s", pprint.pformat(args))
94-
response: Dict[str, Any] = client_athena.start_query_execution(**args)
95+
response: Dict[str, Any] = _utils.try_it(
96+
f=client_athena.start_query_execution,
97+
ex=botocore.exceptions.ClientError,
98+
ex_code="ThrottlingException",
99+
max_num_tries=5,
100+
**args,
101+
)
95102
return cast(str, response["QueryExecutionId"])
96103

97104

@@ -608,7 +615,16 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
608615
609616
"""
610617
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
611-
return cast(Dict[str, Any], client_athena.get_work_group(WorkGroup=workgroup))
618+
return cast(
619+
Dict[str, Any],
620+
_utils.try_it(
621+
f=client_athena.get_work_group,
622+
ex=botocore.exceptions.ClientError,
623+
ex_code="ThrottlingException",
624+
max_num_tries=5,
625+
WorkGroup=workgroup,
626+
),
627+
)
612628

613629

614630
def stop_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> None:
@@ -659,20 +675,20 @@ def wait_query(query_execution_id: str, boto3_session: Optional[boto3.Session] =
659675
>>> res = wr.athena.wait_query(query_execution_id='query-execution-id')
660676
661677
"""
662-
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
663-
response: Dict[str, Any] = client_athena.get_query_execution(QueryExecutionId=query_execution_id)
664-
state: str = response["QueryExecution"]["Status"]["State"]
678+
session: boto3.Session = _utils.ensure_session(session=boto3_session)
679+
response: Dict[str, Any] = get_query_execution(query_execution_id=query_execution_id, boto3_session=session)
680+
state: str = response["Status"]["State"]
665681
while state not in _QUERY_FINAL_STATES:
666682
time.sleep(_QUERY_WAIT_POLLING_DELAY)
667-
response = client_athena.get_query_execution(QueryExecutionId=query_execution_id)
668-
state = response["QueryExecution"]["Status"]["State"]
683+
response = get_query_execution(query_execution_id=query_execution_id, boto3_session=session)
684+
state = response["Status"]["State"]
669685
_logger.debug("state: %s", state)
670-
_logger.debug("StateChangeReason: %s", response["QueryExecution"]["Status"].get("StateChangeReason"))
686+
_logger.debug("StateChangeReason: %s", response["Status"].get("StateChangeReason"))
671687
if state == "FAILED":
672-
raise exceptions.QueryFailed(response["QueryExecution"]["Status"].get("StateChangeReason"))
688+
raise exceptions.QueryFailed(response["Status"].get("StateChangeReason"))
673689
if state == "CANCELLED":
674-
raise exceptions.QueryCancelled(response["QueryExecution"]["Status"].get("StateChangeReason"))
675-
return cast(Dict[str, Any], response["QueryExecution"])
690+
raise exceptions.QueryCancelled(response["Status"].get("StateChangeReason"))
691+
return response
676692

677693

678694
def get_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
@@ -699,5 +715,11 @@ def get_query_execution(query_execution_id: str, boto3_session: Optional[boto3.S
699715
700716
"""
701717
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
702-
response: Dict[str, Any] = client_athena.get_query_execution(QueryExecutionId=query_execution_id)
718+
response: Dict[str, Any] = _utils.try_it(
719+
f=client_athena.get_query_execution,
720+
ex=botocore.exceptions.ClientError,
721+
ex_code="ThrottlingException",
722+
max_num_tries=5,
723+
QueryExecutionId=query_execution_id,
724+
)
703725
return cast(Dict[str, Any], response["QueryExecution"])

awswrangler/catalog/_get.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict, Iterator, List, Optional, Union, cast
88

99
import boto3
10+
import botocore.exceptions
1011
import pandas as pd
1112

1213
from awswrangler import _utils, exceptions
@@ -511,7 +512,14 @@ def get_connection(
511512
"""
512513
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
513514

514-
res = client_glue.get_connection(**_catalog_id(catalog_id=catalog_id, Name=name, HidePassword=False))["Connection"]
515+
res = _utils.try_it(
516+
f=client_glue.get_connection,
517+
ex=botocore.exceptions.ClientError,
518+
ex_code="ThrottlingException",
519+
max_num_tries=3,
520+
**_catalog_id(catalog_id=catalog_id, Name=name, HidePassword=False),
521+
)["Connection"]
522+
515523
if "ENCRYPTED_PASSWORD" in res["ConnectionProperties"]:
516524
client_kms = _utils.client(service_name="kms", session=boto3_session)
517525
pwd = client_kms.decrypt(CiphertextBlob=base64.b64decode(res["ConnectionProperties"]["ENCRYPTED_PASSWORD"]))[

tests/test_s3.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,7 @@ def test_merge_additional_kwargs(path, kms_key_id, s3_additional_kwargs, use_thr
207207
@pytest.mark.parametrize("use_threads", [True, False])
208208
def test_copy(path, path2, use_threads):
209209
df = pd.DataFrame({"id": [1, 2, 3], "par": [1, 2, 3]})
210-
wr.s3.to_parquet(
211-
df=df, path=path, dataset=True, partition_cols=["par"], mode="overwrite", use_threads=use_threads
212-
)
210+
wr.s3.to_parquet(df=df, path=path, dataset=True, partition_cols=["par"], mode="overwrite", use_threads=use_threads)
213211
df = wr.s3.read_parquet(path=path, dataset=True, use_threads=use_threads)
214212
assert df.id.sum() == 6
215213
assert df.par.astype("int").sum() == 6

0 commit comments

Comments
 (0)