Skip to content

Commit 6456883

Browse files
author
malavhs
committed
address comments
1 parent 3fed9f4 commit 6456883

File tree

8 files changed

+19
-39
lines changed

8 files changed

+19
-39
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,6 @@
222222

223223
JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub"
224224

225-
JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub"
226-
227225
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
228226
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
229227

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
447447
return payloads.retrieve_example(
448448
model_id=self.model_id,
449449
model_version=self.model_version,
450+
hub_arn=self.hub_arn,
450451
model_type=self.model_type,
451452
region=self.region,
452453
tolerate_deprecated_model=self.tolerate_deprecated_model,

tests/integ/sagemaker/jumpstart/conftest.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
2424
HUB_NAME_PREFIX,
2525
JUMPSTART_TAG,
26-
SM_JUMPSTART_PUBLIC_HUB_NAME,
2726
)
2827

2928
from sagemaker.jumpstart.types import (
@@ -37,7 +36,7 @@
3736
get_sm_session,
3837
)
3938

40-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
39+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_MODEL_HUB_NAME
4140

4241

4342
def _setup():
@@ -144,7 +143,7 @@ def _delete_hubs(sagemaker_session):
144143
)
145144

146145
for hub in list_hub_response["HubSummaries"]:
147-
if hub["HubName"] != SM_JUMPSTART_PUBLIC_HUB_NAME:
146+
if hub["HubName"] != JUMPSTART_MODEL_HUB_NAME:
148147
# delete all hub contents first
149148
_delete_hub_contents(sagemaker_session, hub["HubName"])
150149
sagemaker_session.delete_hub(hub["HubName"])

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4141

4242
JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id"
4343

44-
SM_JUMPSTART_PUBLIC_HUB_NAME = "SageMakerPublicHub"
45-
4644
HUB_NAME_PREFIX = "PySDK-HubTest-"
4745

4846
TRAINING_DATASET_MODEL_DICT = {

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):
229229

230230

231231
@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning")
232-
def test_instatiating_model(mock_warning_logger, setup):
232+
def test_instantiating_model(mock_warning_logger, setup):
233233

234234
model_id = "catboost-regression-model"
235235

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949

5050
@pytest.fixture(scope="session")
51-
def add_models():
51+
def add_model_references():
5252
# Create Model References to test in Hub
5353
hub_instance = Hub(
5454
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
@@ -57,27 +57,27 @@ def add_models():
5757
hub_instance.create_model_reference(model_arn=get_public_hub_model_arn(hub_instance, model))
5858

5959

60-
def test_jumpstart_hub_model(setup, add_models):
61-
62-
JUMPSTART_LOGGER.info("starting test")
63-
JUMPSTART_LOGGER.info(f"get identity {get_sm_session().get_caller_identity_arn()}")
60+
def test_jumpstart_hub_model(setup, add_model_references):
6461

6562
model_id = "catboost-classification-model"
6663

64+
sagemaker_session = get_sm_session()
65+
6766
model = JumpStartModel(
6867
model_id=model_id,
69-
role=get_sm_session().get_caller_identity_arn(),
70-
sagemaker_session=get_sm_session(),
68+
role=sagemaker_session.get_caller_identity_arn(),
69+
sagemaker_session=sagemaker_session,
7170
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
7271
)
7372

74-
# uses ml.m5.4xlarge instance
75-
model.deploy(
73+
predictor = model.deploy(
7674
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
7775
)
7876

77+
assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name)
7978

80-
def test_jumpstart_hub_gated_model(setup, add_models):
79+
80+
def test_jumpstart_hub_gated_model(setup, add_model_references):
8181

8282
model_id = "meta-textgeneration-llama-3-2-1b"
8383

@@ -88,23 +88,19 @@ def test_jumpstart_hub_gated_model(setup, add_models):
8888
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
8989
)
9090

91-
# uses ml.g6.xlarge instance
9291
predictor = model.deploy(
9392
accept_eula=True,
9493
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
9594
)
9695

97-
payload = {
98-
"inputs": "some-payload",
99-
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
100-
}
96+
payload = model.retrieve_example_payload()
10197

102-
response = predictor.predict(payload, custom_attributes="accept_eula=true")
98+
response = predictor.predict(payload)
10399

104100
assert response is not None
105101

106102

107-
def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
103+
def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references):
108104

109105
model_id = "meta-textgeneration-llama-2-7b"
110106

@@ -125,7 +121,6 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
125121
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
126122
)
127123

128-
# uses ml.g5.2xlarge instance
129124
model.deploy(
130125
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
131126
accept_eula=True,
@@ -139,10 +134,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
139134
hub_arn=hub_arn,
140135
)
141136

142-
payload = {
143-
"inputs": "some-payload",
144-
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
145-
}
137+
payload = model.retrieve_example_payload()
146138

147139
response = predictor.predict(payload)
148140

@@ -156,7 +148,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
156148
assert model.inference_component_name == predictor.component_name
157149

158150

159-
def test_instatiating_model(setup, add_models):
151+
def test_instantiating_model(setup, add_model_references):
160152

161153
model_id = "catboost-regression-model"
162154

tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,14 @@ def test_hub_model_reference(setup):
3131
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
3232
)
3333

34-
# Create Model Reference
3534
create_model_response = hub_instance.create_model_reference(
3635
model_arn=get_public_hub_model_arn(hub_instance, model_id)
3736
)
3837
assert create_model_response is not None
3938

40-
# Describe Model
4139
describe_model_response = hub_instance.describe_model(model_name=model_id)
4240
assert describe_model_response is not None
4341
assert type(describe_model_response) == DescribeHubContentResponse
4442

45-
# Delete Model Reference
4643
delete_model_response = hub_instance.delete_model_reference(model_name=model_id)
4744
assert delete_model_response is not None

tests/integ/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,6 @@ def get_public_hub_model_arn(hub: Hub, model_id: str) -> str:
121121
response = hub.list_sagemaker_public_hub_models(filter=filter_value)
122122

123123
models = response["hub_content_summaries"]
124-
while response["next_token"]:
125-
response = hub.list_sagemaker_public_hub_models(
126-
filter=filter_value, next_token=response["next_token"]
127-
)
128-
models.extend(response["hub_content_summaries"])
129124

130125
return models[0]["hub_content_arn"]
131126

0 commit comments

Comments
 (0)