Skip to content

Commit bfeb2c0

Browse files
author
malavhs
committed
linting
1 parent 8910f50 commit bfeb2c0

File tree

6 files changed

+64
-57
lines changed

6 files changed

+64
-57
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ def _get_deployment_configs(
10441044
sagemaker_session=self.sagemaker_session,
10451045
region=self.region,
10461046
model_version=self.model_version,
1047-
hub_arn=self.hub_arn
1047+
hub_arn=self.hub_arn,
10481048
)
10491049

10501050
deployment_config_metadata = DeploymentConfigMetadata(

tests/integ/sagemaker/jumpstart/conftest.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,14 @@ def _setup():
4848
test_hub_description = "PySDK Integ Test Private Hub"
4949
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suit_id})
5050
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name})
51-
hub = Hub(hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session())
51+
hub = Hub(
52+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
53+
)
5254
hub.create(description=test_hub_description)
5355
describe_hub_response = hub.describe()
5456
JUMPSTART_LOGGER.info(f"Describe Hub {describe_hub_response}")
5557

58+
5659
def _teardown():
5760
print("Tearing down...")
5861

@@ -133,38 +136,35 @@ def _teardown():
133136

134137
# delete private hubs
135138
_delete_hubs(sagemaker_session)
136-
139+
137140

138141
def _delete_hubs(sagemaker_session):
139-
#list Hubs created by PySDK integration tests
142+
# list Hubs created by PySDK integration tests
140143
list_hub_response = sagemaker_session.list_hubs(name_contains=HUB_NAME_PREFIX)
141144

142-
for hub in list_hub_response['HubSummaries']:
143-
if hub['HubName'] != SM_JUMPSTART_PUBLIC_HUB_NAME:
144-
#delete all hub contents first
145-
_delete_hub_contents(sagemaker_session, hub['HubName'])
145+
for hub in list_hub_response["HubSummaries"]:
146+
if hub["HubName"] != SM_JUMPSTART_PUBLIC_HUB_NAME:
147+
# delete all hub contents first
148+
_delete_hub_contents(sagemaker_session, hub["HubName"])
146149
JUMPSTART_LOGGER.info(f"Deleting {hub['HubName']}")
147-
sagemaker_session.delete_hub(hub['HubName'])
150+
sagemaker_session.delete_hub(hub["HubName"])
148151

149152

150153
def _delete_hub_contents(sagemaker_session, test_hub_name):
151-
#list hub_contents for the given hub
154+
# list hub_contents for the given hub
152155
list_hub_content_response = sagemaker_session.list_hub_contents(
153-
hub_name=test_hub_name,
154-
hub_content_type=HubContentType.MODEL_REFERENCE.value
156+
hub_name=test_hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value
155157
)
156158
JUMPSTART_LOGGER.info(f"Listing HubContents {list_hub_content_response}")
157159

158-
#delete hub_contents for the given hub
159-
for models in list_hub_content_response['HubContentSummaries']:
160+
# delete hub_contents for the given hub
161+
for models in list_hub_content_response["HubContentSummaries"]:
160162
sagemaker_session.delete_hub_content_reference(
161-
hub_name=test_hub_name,
163+
hub_name=test_hub_name,
162164
hub_content_type=HubContentType.MODEL_REFERENCE.value,
163-
hub_content_name=models['HubContentName']
165+
hub_content_name=models["HubContentName"],
164166
)
165167

166-
167-
168168

169169
@pytest.fixture(scope="session", autouse=True)
170170
def setup(request):

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,16 @@
5050
"catboost-regression-model",
5151
}
5252

53+
5354
@pytest.fixture(scope="module")
5455
def add_models():
5556
# Create Model References to test in Hub
56-
hub_instance = Hub(hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session())
57+
hub_instance = Hub(
58+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
59+
)
5760
for model in TEST_MODEL_IDS:
58-
hub_instance.create_model_reference(
59-
model_arn = get_public_hub_model_arn(hub_instance, model)
60-
)
61+
hub_instance.create_model_reference(model_arn=get_public_hub_model_arn(hub_instance, model))
62+
6163

6264
def test_jumpstart_hub_model(setup, add_models):
6365

@@ -70,14 +72,15 @@ def test_jumpstart_hub_model(setup, add_models):
7072
model_id=model_id,
7173
role=get_sm_session().get_caller_identity_arn(),
7274
sagemaker_session=get_sm_session(),
73-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
75+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
7476
)
7577

7678
# uses ml.m5.4xlarge instance
7779
model.deploy(
7880
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
7981
)
8082

83+
8184
def test_jumpstart_hub_gated_model(setup, add_models):
8285

8386
model_id = "meta-textgeneration-llama-3-2-1b"
@@ -86,7 +89,7 @@ def test_jumpstart_hub_gated_model(setup, add_models):
8689
model_id=model_id,
8790
role=get_sm_session().get_caller_identity_arn(),
8891
sagemaker_session=get_sm_session(),
89-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
92+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
9093
)
9194

9295
# uses ml.g6.xlarge instance
@@ -104,6 +107,7 @@ def test_jumpstart_hub_gated_model(setup, add_models):
104107

105108
assert response is not None
106109

110+
107111
def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
108112

109113
model_id = "meta-textgeneration-llama-2-7b"
@@ -115,14 +119,14 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
115119
sagemaker_session = get_sm_session()
116120

117121
hub_arn = generate_hub_arn_for_init_kwargs(
118-
hub_name=hub_name, region=region, session=sagemaker_session
119-
)
122+
hub_name=hub_name, region=region, session=sagemaker_session
123+
)
120124

121125
model = JumpStartModel(
122126
model_id=model_id,
123127
role=get_sm_session().get_caller_identity_arn(),
124128
sagemaker_session=sagemaker_session,
125-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
129+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
126130
)
127131

128132
# uses ml.g5.2xlarge instance
@@ -136,7 +140,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
136140
endpoint_name=model.endpoint_name,
137141
sagemaker_session=sagemaker_session,
138142
tolerate_vulnerable_model=True,
139-
hub_arn=hub_arn
143+
hub_arn=hub_arn,
140144
)
141145

142146
payload = {
@@ -149,13 +153,13 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
149153
assert response is not None
150154

151155
model = JumpStartModel.attach(
152-
predictor.endpoint_name,
153-
sagemaker_session=sagemaker_session,
154-
hub_name=hub_name)
156+
predictor.endpoint_name, sagemaker_session=sagemaker_session, hub_name=hub_name
157+
)
155158
assert model.model_id == model_id
156159
assert model.endpoint_name == predictor.endpoint_name
157160
assert model.inference_component_name == predictor.component_name
158161

162+
159163
def test_instatiating_model(setup, add_models):
160164

161165
model_id = "catboost-regression-model"
@@ -166,10 +170,9 @@ def test_instatiating_model(setup, add_models):
166170
model_id=model_id,
167171
role=get_sm_session().get_caller_identity_arn(),
168172
sagemaker_session=get_sm_session(),
169-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
173+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
170174
)
171175

172176
elapsed_time = time.perf_counter() - start_time
173177

174178
assert elapsed_time <= MAX_INIT_TIME_SECONDS
175-

tests/integ/sagemaker/jumpstart/private_hub/test_hub.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,29 @@
1414
HUB_NAME_PREFIX,
1515
)
1616

17+
1718
@pytest.fixture
1819
def hub_instance():
19-
HUB_NAME=f"{HUB_NAME_PREFIX}-{get_test_suite_id()}"
20+
HUB_NAME = f"{HUB_NAME_PREFIX}-{get_test_suite_id()}"
2021
hub = Hub(HUB_NAME, sagemaker_session=get_sm_session())
2122
yield hub
2223

24+
2325
def test_private_hub(setup, hub_instance):
24-
#Createhub
26+
# Createhub
2527
create_hub_response = hub_instance.create(
26-
description="This is a Test Private Hub.",
27-
display_name="malavhs Test hub",
28-
search_keywords=["jumpstart-sdk-integ-test"],
28+
description="This is a Test Private Hub.",
29+
display_name="malavhs Test hub",
30+
search_keywords=["jumpstart-sdk-integ-test"],
2931
)
3032

31-
#Create Hub Verifications
33+
# Create Hub Verifications
3234
assert create_hub_response is not None
3335

34-
#Describe Hub
36+
# Describe Hub
3537
hub_description = hub_instance.describe()
3638
assert hub_description is not None
3739

38-
#Delete Hub
40+
# Delete Hub
3941
delete_hub_response = hub_instance.delete()
40-
assert delete_hub_response is not None
42+
assert delete_hub_response is not None

tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from tests.integ.sagemaker.jumpstart.utils import (
77
get_sm_session,
88
)
9-
from tests.integ.sagemaker.jumpstart.utils import (
10-
get_public_hub_model_arn
11-
)
9+
from tests.integ.sagemaker.jumpstart.utils import get_public_hub_model_arn
1210
from tests.integ.sagemaker.jumpstart.constants import (
1311
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
1412
)
@@ -18,21 +16,21 @@
1816
def test_hub_model_reference(setup):
1917
model_id = "meta-textgenerationneuron-llama-3-2-1b-instruct"
2018

21-
hub_instance = Hub(hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session())
19+
hub_instance = Hub(
20+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
21+
)
2222

23-
#Create Model Reference
23+
# Create Model Reference
2424
create_model_response = hub_instance.create_model_reference(
25-
model_arn = get_public_hub_model_arn(hub_instance, model_id)
25+
model_arn=get_public_hub_model_arn(hub_instance, model_id)
2626
)
27-
assert create_model_response is not None
27+
assert create_model_response is not None
2828

29-
#Describe Model
30-
describe_model_response = hub_instance.describe_model(
31-
model_name = model_id
32-
)
29+
# Describe Model
30+
describe_model_response = hub_instance.describe_model(model_name=model_id)
3331
assert describe_model_response is not None
3432
assert type(describe_model_response) == DescribeHubContentResponse
3533

36-
#Delete Model Reference
34+
# Delete Model Reference
3735
delete_model_response = hub_instance.delete_model_reference(model_name=model_id)
38-
assert delete_model_response is not None
36+
assert delete_model_response is not None

tests/integ/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,20 @@ def get_tabular_data(data_filename: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
115115
def download_file(local_download_path, s3_bucket, s3_key, s3_client) -> None:
116116
s3_client.download_file(s3_bucket, s3_key, local_download_path)
117117

118+
118119
def get_public_hub_model_arn(hub: Hub, model_id: str) -> str:
119120
filter_value = f"model_id == {model_id}"
120121
response = hub.list_sagemaker_public_hub_models(filter=filter_value)
121122

122123
models = response["hub_content_summaries"]
123124
while response["next_token"]:
124-
response = hub.list_sagemaker_public_hub_models(filter=filter_value, next_token=response["next_token"])
125+
response = hub.list_sagemaker_public_hub_models(
126+
filter=filter_value, next_token=response["next_token"]
127+
)
125128
models.extend(response["hub_content_summaries"])
126129

127-
return models[0]['hub_content_arn']
130+
return models[0]["hub_content_arn"]
131+
128132

129133
class EndpointInvoker:
130134
def __init__(

0 commit comments

Comments
 (0)