Skip to content

Commit 8910f50

Browse files
author
malavhs
committed
tests: Implement integration tests covering JumpStart PrivateHub workflows
1 parent ec89f7d commit 8910f50

File tree

10 files changed

+328
-2
lines changed

10 files changed

+328
-2
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,13 +1036,15 @@ def _get_deployment_configs(
10361036
image_uri=image_uri,
10371037
region=self.region,
10381038
model_version=self.model_version,
1039+
hub_arn=self.hub_arn,
10391040
)
10401041
deploy_kwargs = get_deploy_kwargs(
10411042
model_id=self.model_id,
10421043
instance_type=instance_type_to_use,
10431044
sagemaker_session=self.sagemaker_session,
10441045
region=self.region,
10451046
model_version=self.model_version,
1047+
hub_arn=self.hub_arn
10461048
)
10471049

10481050
deployment_config_metadata = DeploymentConfigMetadata(

tests/integ/sagemaker/jumpstart/conftest.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,50 @@
1616
import boto3
1717
import pytest
1818
from botocore.config import Config
19+
from sagemaker.jumpstart.hub.hub import Hub
1920
from sagemaker.session import Session
2021
from tests.integ.sagemaker.jumpstart.constants import (
2122
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
23+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
24+
HUB_NAME_PREFIX,
2225
JUMPSTART_TAG,
26+
SM_JUMPSTART_PUBLIC_HUB_NAME,
2327
)
2428

29+
from sagemaker.jumpstart.types import (
30+
HubContentType,
31+
)
32+
33+
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
2534

2635
from tests.integ.sagemaker.jumpstart.utils import (
2736
get_test_artifact_bucket,
2837
get_test_suite_id,
38+
get_sm_session,
2939
)
3040

3141
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
3242

3343

3444
def _setup():
3545
print("Setting up...")
36-
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: get_test_suite_id()})
37-
46+
test_suit_id = get_test_suite_id()
47+
test_hub_name = f"{HUB_NAME_PREFIX}{test_suit_id}"
48+
test_hub_description = "PySDK Integ Test Private Hub"
49+
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suit_id})
50+
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name})
51+
hub = Hub(hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session())
52+
hub.create(description=test_hub_description)
53+
describe_hub_response = hub.describe()
54+
JUMPSTART_LOGGER.info(f"Describe Hub {describe_hub_response}")
3855

3956
def _teardown():
4057
print("Tearing down...")
4158

4259
test_cache_bucket = get_test_artifact_bucket()
4360

4461
test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
62+
test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
4563

4664
boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)
4765

@@ -113,6 +131,40 @@ def _teardown():
113131
bucket = s3_resource.Bucket(test_cache_bucket)
114132
bucket.objects.filter(Prefix=test_suite_id + "/").delete()
115133

134+
# delete private hubs
135+
_delete_hubs(sagemaker_session)
136+
137+
138+
def _delete_hubs(sagemaker_session):
139+
#list Hubs created by PySDK integration tests
140+
list_hub_response = sagemaker_session.list_hubs(name_contains=HUB_NAME_PREFIX)
141+
142+
for hub in list_hub_response['HubSummaries']:
143+
if hub['HubName'] != SM_JUMPSTART_PUBLIC_HUB_NAME:
144+
#delete all hub contents first
145+
_delete_hub_contents(sagemaker_session, hub['HubName'])
146+
JUMPSTART_LOGGER.info(f"Deleting {hub['HubName']}")
147+
sagemaker_session.delete_hub(hub['HubName'])
148+
149+
150+
def _delete_hub_contents(sagemaker_session, test_hub_name):
151+
#list hub_contents for the given hub
152+
list_hub_content_response = sagemaker_session.list_hub_contents(
153+
hub_name=test_hub_name,
154+
hub_content_type=HubContentType.MODEL_REFERENCE.value
155+
)
156+
JUMPSTART_LOGGER.info(f"Listing HubContents {list_hub_content_response}")
157+
158+
#delete hub_contents for the given hub
159+
for models in list_hub_content_response['HubContentSummaries']:
160+
sagemaker_session.delete_hub_content_reference(
161+
hub_name=test_hub_name,
162+
hub_content_type=HubContentType.MODEL_REFERENCE.value,
163+
hub_content_name=models['HubContentName']
164+
)
165+
166+
167+
116168

117169
@pytest.fixture(scope="session", autouse=True)
118170
def setup(request):

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,13 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
3737

3838
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID = "JUMPSTART_SDK_TEST_SUITE_ID"
3939

40+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME = "JUMPSTART_SDK_TEST_HUB_NAME"
41+
4042
JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id"
4143

44+
SM_JUMPSTART_PUBLIC_HUB_NAME = "SageMakerPublicHub"
45+
46+
HUB_NAME_PREFIX = "PySDK-HubTest-"
4247

4348
TRAINING_DATASET_MODEL_DICT = {
4449
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656

5757
def test_non_prepacked_jumpstart_model(setup):
5858

59+
JUMPSTART_LOGGER.info("starting test")
60+
JUMPSTART_LOGGER.info(f"get identity {get_sm_session().get_caller_identity_arn()}")
61+
5962
model_id = "catboost-classification-model"
6063

6164
model = JumpStartModel(

tests/integ/sagemaker/jumpstart/private_hub/__init__.py

Whitespace-only changes.

tests/integ/sagemaker/jumpstart/private_hub/model/__init__.py

Whitespace-only changes.
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import io
16+
import os
17+
import sys
18+
import time
19+
from unittest import mock
20+
import logging
21+
22+
import pytest
23+
from sagemaker.enums import EndpointType
24+
from sagemaker.jumpstart.hub.hub import Hub
25+
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs
26+
from sagemaker.predictor import retrieve_default
27+
28+
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
29+
30+
import tests.integ
31+
32+
from sagemaker.jumpstart.model import JumpStartModel
33+
from tests.integ.sagemaker.jumpstart.constants import (
34+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
35+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
36+
JUMPSTART_TAG,
37+
)
38+
from tests.integ.sagemaker.jumpstart.utils import (
39+
get_public_hub_model_arn,
40+
get_sm_session,
41+
)
42+
43+
MAX_INIT_TIME_SECONDS = 5
44+
45+
TEST_MODEL_IDS = {
46+
"catboost-classification-model",
47+
"huggingface-txt2img-conflictx-complex-lineart",
48+
"meta-textgeneration-llama-2-7b",
49+
"meta-textgeneration-llama-3-2-1b",
50+
"catboost-regression-model",
51+
}
52+
53+
@pytest.fixture(scope="module")
54+
def add_models():
55+
# Create Model References to test in Hub
56+
hub_instance = Hub(hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session())
57+
for model in TEST_MODEL_IDS:
58+
hub_instance.create_model_reference(
59+
model_arn = get_public_hub_model_arn(hub_instance, model)
60+
)
61+
62+
def test_jumpstart_hub_model(setup, add_models):
63+
64+
JUMPSTART_LOGGER.info("starting test")
65+
JUMPSTART_LOGGER.info(f"get identity {get_sm_session().get_caller_identity_arn()}")
66+
67+
model_id = "catboost-classification-model"
68+
69+
model = JumpStartModel(
70+
model_id=model_id,
71+
role=get_sm_session().get_caller_identity_arn(),
72+
sagemaker_session=get_sm_session(),
73+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
74+
)
75+
76+
# uses ml.m5.4xlarge instance
77+
model.deploy(
78+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
79+
)
80+
81+
def test_jumpstart_hub_gated_model(setup, add_models):
82+
83+
model_id = "meta-textgeneration-llama-3-2-1b"
84+
85+
model = JumpStartModel(
86+
model_id=model_id,
87+
role=get_sm_session().get_caller_identity_arn(),
88+
sagemaker_session=get_sm_session(),
89+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
90+
)
91+
92+
# uses ml.g6.xlarge instance
93+
predictor = model.deploy(
94+
accept_eula=True,
95+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
96+
)
97+
98+
payload = {
99+
"inputs": "some-payload",
100+
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
101+
}
102+
103+
response = predictor.predict(payload, custom_attributes="accept_eula=true")
104+
105+
assert response is not None
106+
107+
def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
108+
109+
model_id = "meta-textgeneration-llama-2-7b"
110+
111+
hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
112+
113+
region = tests.integ.test_region()
114+
115+
sagemaker_session = get_sm_session()
116+
117+
hub_arn = generate_hub_arn_for_init_kwargs(
118+
hub_name=hub_name, region=region, session=sagemaker_session
119+
)
120+
121+
model = JumpStartModel(
122+
model_id=model_id,
123+
role=get_sm_session().get_caller_identity_arn(),
124+
sagemaker_session=sagemaker_session,
125+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
126+
)
127+
128+
# uses ml.g5.2xlarge instance
129+
model.deploy(
130+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
131+
accept_eula=True,
132+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
133+
)
134+
135+
predictor = retrieve_default(
136+
endpoint_name=model.endpoint_name,
137+
sagemaker_session=sagemaker_session,
138+
tolerate_vulnerable_model=True,
139+
hub_arn=hub_arn
140+
)
141+
142+
payload = {
143+
"inputs": "some-payload",
144+
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
145+
}
146+
147+
response = predictor.predict(payload)
148+
149+
assert response is not None
150+
151+
model = JumpStartModel.attach(
152+
predictor.endpoint_name,
153+
sagemaker_session=sagemaker_session,
154+
hub_name=hub_name)
155+
assert model.model_id == model_id
156+
assert model.endpoint_name == predictor.endpoint_name
157+
assert model.inference_component_name == predictor.component_name
158+
159+
def test_instatiating_model(setup, add_models):
160+
161+
model_id = "catboost-regression-model"
162+
163+
start_time = time.perf_counter()
164+
165+
JumpStartModel(
166+
model_id=model_id,
167+
role=get_sm_session().get_caller_identity_arn(),
168+
sagemaker_session=get_sm_session(),
169+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
170+
)
171+
172+
elapsed_time = time.perf_counter() - start_time
173+
174+
assert elapsed_time <= MAX_INIT_TIME_SECONDS
175+
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pytest
2+
import os
3+
from unittest.mock import MagicMock, patch
4+
from sagemaker.jumpstart.hub.hub import Hub
5+
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
6+
7+
from tests.integ.sagemaker.jumpstart.utils import (
8+
get_sm_session,
9+
)
10+
from tests.integ.sagemaker.jumpstart.utils import (
11+
get_test_suite_id,
12+
)
13+
from tests.integ.sagemaker.jumpstart.constants import (
14+
HUB_NAME_PREFIX,
15+
)
16+
17+
@pytest.fixture
18+
def hub_instance():
19+
HUB_NAME=f"{HUB_NAME_PREFIX}-{get_test_suite_id()}"
20+
hub = Hub(HUB_NAME, sagemaker_session=get_sm_session())
21+
yield hub
22+
23+
def test_private_hub(setup, hub_instance):
24+
#Createhub
25+
create_hub_response = hub_instance.create(
26+
description="This is a Test Private Hub.",
27+
display_name="malavhs Test hub",
28+
search_keywords=["jumpstart-sdk-integ-test"],
29+
)
30+
31+
#Create Hub Verifications
32+
assert create_hub_response is not None
33+
34+
#Describe Hub
35+
hub_description = hub_instance.describe()
36+
assert hub_description is not None
37+
38+
#Delete Hub
39+
delete_hub_response = hub_instance.delete()
40+
assert delete_hub_response is not None
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
import os
3+
from sagemaker.jumpstart.hub.hub import Hub
4+
5+
from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse
6+
from tests.integ.sagemaker.jumpstart.utils import (
7+
get_sm_session,
8+
)
9+
from tests.integ.sagemaker.jumpstart.utils import (
10+
get_public_hub_model_arn
11+
)
12+
from tests.integ.sagemaker.jumpstart.constants import (
13+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
14+
)
15+
import tests
16+
17+
18+
def test_hub_model_reference(setup):
19+
model_id = "meta-textgenerationneuron-llama-3-2-1b-instruct"
20+
21+
hub_instance = Hub(hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session())
22+
23+
#Create Model Reference
24+
create_model_response = hub_instance.create_model_reference(
25+
model_arn = get_public_hub_model_arn(hub_instance, model_id)
26+
)
27+
assert create_model_response is not None
28+
29+
#Describe Model
30+
describe_model_response = hub_instance.describe_model(
31+
model_name = model_id
32+
)
33+
assert describe_model_response is not None
34+
assert type(describe_model_response) == DescribeHubContentResponse
35+
36+
#Delete Model Reference
37+
delete_model_response = hub_instance.delete_model_reference(model_name=model_id)
38+
assert delete_model_response is not None

0 commit comments

Comments
 (0)