Skip to content

Commit 556d120

Browse files
author
malavhs
committed
implement throttle handling
1 parent bac00dd commit 556d120

File tree

3 files changed

+56
-23
lines changed

3 files changed

+56
-23
lines changed

tests/integ/sagemaker/jumpstart/conftest.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
get_test_artifact_bucket,
3535
get_test_suite_id,
3636
get_sm_session,
37+
with_exponential_backoff,
3738
)
3839

3940
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_MODEL_HUB_NAME
@@ -62,6 +63,8 @@ def _teardown():
6263

6364
test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
6465

66+
test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
67+
6568
boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)
6669

6770
sagemaker_client = boto3_session.client(
@@ -133,36 +136,28 @@ def _teardown():
133136
bucket.objects.filter(Prefix=test_suite_id + "/").delete()
134137

135138
# delete private hubs
136-
_delete_hubs(sagemaker_session)
139+
_delete_hubs(sagemaker_session, test_hub_name)
137140

138141

139-
def _delete_hubs(sagemaker_session):
140-
# list Hubs created by PySDK integration tests
141-
list_hub_response = sagemaker_session.list_hubs(
142-
name_contains=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
142+
def _delete_hubs(sagemaker_session, hub_name):
143+
# list and delete all hub contents first
144+
list_hub_content_response = sagemaker_session.list_hub_contents(
145+
hub_name=hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value
143146
)
147+
for model in list_hub_content_response["HubContentSummaries"]:
148+
_delete_hub_contents(sagemaker_session, hub_name, model)
144149

145-
for hub in list_hub_response["HubSummaries"]:
146-
if hub["HubName"] != JUMPSTART_MODEL_HUB_NAME:
147-
# delete all hub contents first
148-
_delete_hub_contents(sagemaker_session, hub["HubName"])
149-
sagemaker_session.delete_hub(hub["HubName"])
150+
sagemaker_session.delete_hub(hub_name)
150151

151152

152-
def _delete_hub_contents(sagemaker_session, test_hub_name):
153-
# list hub_contents for the given hub
154-
list_hub_content_response = sagemaker_session.list_hub_contents(
155-
hub_name=test_hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value
153+
@with_exponential_backoff()
154+
def _delete_hub_contents(sagemaker_session, hub_name, model):
155+
sagemaker_session.delete_hub_content_reference(
156+
hub_name=hub_name,
157+
hub_content_type=HubContentType.MODEL_REFERENCE.value,
158+
hub_content_name=model["HubContentName"],
156159
)
157160

158-
# delete hub_contents for the given hub
159-
for models in list_hub_content_response["HubContentSummaries"]:
160-
sagemaker_session.delete_hub_content_reference(
161-
hub_name=test_hub_name,
162-
hub_content_type=HubContentType.MODEL_REFERENCE.value,
163-
hub_content_name=models["HubContentName"],
164-
)
165-
166161

167162
@pytest.fixture(scope="session", autouse=True)
168163
def setup(request):

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import random
1617
import time
1718

1819
import pytest
1920
from sagemaker.enums import EndpointType
2021
from sagemaker.jumpstart.hub.hub import Hub
2122
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs
2223
from sagemaker.predictor import retrieve_default
24+
from botocore.exceptions import ClientError
2325

2426
import tests.integ
2527

@@ -32,6 +34,7 @@
3234
from tests.integ.sagemaker.jumpstart.utils import (
3335
get_public_hub_model_arn,
3436
get_sm_session,
37+
with_exponential_backoff,
3538
)
3639

3740
MAX_INIT_TIME_SECONDS = 5
@@ -45,14 +48,20 @@
4548
}
4649

4750

51+
@with_exponential_backoff()
52+
def create_model_reference(hub_instance, model_arn):
53+
hub_instance.create_model_reference(model_arn=model_arn)
54+
55+
4856
@pytest.fixture(scope="session")
4957
def add_model_references():
5058
# Create Model References to test in Hub
5159
hub_instance = Hub(
5260
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
5361
)
5462
for model in TEST_MODEL_IDS:
55-
hub_instance.create_model_reference(model_arn=get_public_hub_model_arn(hub_instance, model))
63+
model_arn = get_public_hub_model_arn(hub_instance, model)
64+
create_model_reference(hub_instance, model_arn)
5665

5766

5867
def test_jumpstart_hub_model(setup, add_model_references):

tests/integ/sagemaker/jumpstart/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414
import functools
1515
import json
1616

17+
import random
18+
import time
1719
import uuid
1820
from typing import Any, Dict, List, Tuple
1921
import boto3
2022
import pandas as pd
2123
import os
2224

2325
from botocore.config import Config
26+
from botocore.exceptions import ClientError
2427
import pytest
2528

2629

@@ -125,6 +128,32 @@ def get_public_hub_model_arn(hub: Hub, model_id: str) -> str:
125128
return models[0]["hub_content_arn"]
126129

127130

131+
def with_exponential_backoff(max_retries=5, initial_delay=1, max_delay=60):
132+
def decorator(func):
133+
@functools.wraps(func)
134+
def wrapper(*args, **kwargs):
135+
retries = 0
136+
while True:
137+
try:
138+
return func(*args, **kwargs)
139+
except ClientError as e:
140+
if retries >= max_retries or e.response["Error"]["Code"] not in [
141+
"ThrottlingException",
142+
"TooManyRequestsException",
143+
]:
144+
raise
145+
delay = min(initial_delay * (2**retries) + random.random(), max_delay)
146+
print(
147+
f"Retrying {func.__name__} in {delay:.2f} seconds... (Attempt {retries + 1}/{max_retries})"
148+
)
149+
time.sleep(delay)
150+
retries += 1
151+
152+
return wrapper
153+
154+
return decorator
155+
156+
128157
class EndpointInvoker:
129158
def __init__(
130159
self,

0 commit comments

Comments
 (0)