Skip to content

Commit 0aeca78

Browse files
committed
feat: add integ tests for training JumpStart models in private hub
1 parent 3822454 commit 0aeca78

File tree

6 files changed

+242
-6
lines changed

6 files changed

+242
-6
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
def _add_hub_access_config_to_kwargs_inputs(
    kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None
):
    """Adds HubAccessConfig to kwargs inputs.

    The hub access config is attached only to inputs whose S3 URI is equal to
    ``kwargs.specs.default_training_dataset_uri``. Inputs pointing anywhere
    else are left without it.

    Handling per input shape:
      * ``str``: replaced by a ``TrainingInput`` carrying the hub access config
        when it matches the default dataset URI; otherwise left as the string.
      * ``TrainingInput``: the hub access config is added in place on a match.
      * ``dict``: every ``str`` value is converted to a ``TrainingInput``
        (with the hub access config only on a match); every ``TrainingInput``
        value gets the hub access config added in place on a match.

    Args:
        kwargs: Fit kwargs whose ``inputs`` attribute is inspected and updated.
        hub_access_config: Value forwarded to ``TrainingInput`` /
            ``add_hub_access_config``.

    Returns:
        The same ``kwargs`` object that was passed in.
    """
    dataset_uri = kwargs.specs.default_training_dataset_uri

    def _is_default_dataset(s3_uri):
        """Return True when ``s3_uri`` is the model's default training dataset."""
        return dataset_uri is not None and dataset_uri == s3_uri

    def _targets_default_dataset(training_input):
        """Return True when ``training_input`` reads the default training dataset."""
        # The config lookup is only evaluated when a default dataset URI exists,
        # matching the short-circuit of the inline condition it replaces.
        return (
            dataset_uri is not None
            and dataset_uri == training_input.config["DataSource"]["S3DataSource"]["S3Uri"]
        )

    inputs = kwargs.inputs
    if isinstance(inputs, str):
        if _is_default_dataset(inputs):
            kwargs.inputs = TrainingInput(s3_data=inputs, hub_access_config=hub_access_config)
    elif isinstance(inputs, TrainingInput):
        if _targets_default_dataset(inputs):
            inputs.add_hub_access_config(hub_access_config=hub_access_config)
    elif isinstance(inputs, dict):
        # Only values of existing keys are reassigned, so the dict size does not
        # change while it is being iterated.
        for channel, value in inputs.items():
            if isinstance(value, str):
                training_input = TrainingInput(s3_data=value)
                if _is_default_dataset(value):
                    training_input.add_hub_access_config(hub_access_config=hub_access_config)
                inputs[channel] = training_input
            elif isinstance(value, TrainingInput):
                # Fix: check the channel's value, not ``kwargs.inputs`` (which is
                # a dict in this branch, making the old check unreachable).
                if _targets_default_dataset(value):
                    value.add_hub_access_config(hub_access_config=hub_access_config)

    return kwargs
328334

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
279279
specs["training_instance_type_variants"] = (
280280
hub_model_document.training_instance_type_variants
281281
)
282+
if hub_model_document.default_training_dataset_uri:
283+
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
284+
hub_model_document.default_training_dataset_uri
285+
)
286+
specs["default_training_dataset_key"] = default_training_dataset_key
287+
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
282288
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
12791279
"hosting_neuron_model_version",
12801280
"hub_content_type",
12811281
"_is_hub_content",
1282+
"default_training_dataset_key",
1283+
"default_training_dataset_uri",
12821284
]
12831285

12841286
_non_serializable_slots = ["_is_hub_content"]
@@ -1462,6 +1464,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
14621464
else None
14631465
)
14641466
self.model_subscription_link = json_obj.get("model_subscription_link")
1467+
self.default_training_dataset_key: Optional[str] = json_obj.get("default_training_dataset_key")
1468+
self.default_training_dataset_uri: Optional[str] = json_obj.get("default_training_dataset_uri")
14651469

14661470
def to_json(self) -> Dict[str, Any]:
14671471
"""Returns json representation of JumpStartMetadataBaseFields object."""

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4747
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
4848
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4949
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
50-
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
50+
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
5151
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
5252
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
5353
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),

tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py

Whitespace-only changes.
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import time
17+
18+
import pytest
19+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
20+
from sagemaker.jumpstart.hub.hub import Hub
21+
22+
from sagemaker.jumpstart.estimator import JumpStartEstimator
23+
from tests.integ.sagemaker.jumpstart.constants import (
24+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
25+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
26+
JUMPSTART_TAG,
27+
)
28+
from tests.integ.sagemaker.jumpstart.utils import (
29+
get_public_hub_model_arn,
30+
get_sm_session,
31+
with_exponential_backoff,
32+
)
33+
from tests.integ.sagemaker.jumpstart.constants import (
34+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
35+
JUMPSTART_TAG,
36+
)
37+
from tests.integ.sagemaker.jumpstart.utils import (
38+
get_sm_session,
39+
get_training_dataset_for_model_and_version
40+
)
41+
42+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
43+
44+
# Upper bound, in seconds, on how long constructing a JumpStartEstimator may take.
MAX_INIT_TIME_SECONDS = 5

# Model IDs for which a model reference is created in the private test hub.
TEST_MODEL_IDS = {
    "catboost-regression-model",
    "huggingface-spc-bert-base-cased",
    "meta-textgeneration-llama-2-7b",
}
51+
52+
53+
@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
    """Create a model reference for ``model_arn`` in ``hub_instance``.

    NOTE(review): ``with_exponential_backoff`` is defined in the shared test
    utils; presumably it retries the call on transient failures — confirm there.
    """
    hub_instance.create_model_reference(model_arn=model_arn)
56+
57+
58+
@pytest.fixture(scope="session")
def add_model_references():
    """Session-scoped fixture: add a model reference for every test model to the test hub."""
    test_hub = Hub(
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        sagemaker_session=get_sm_session(),
    )
    for model_id in TEST_MODEL_IDS:
        public_model_arn = get_public_hub_model_arn(test_hub, model_id)
        create_model_reference(test_hub, public_model_arn)
67+
68+
69+
def test_jumpstart_hub_estimator(setup, add_model_references):
    """Train, re-attach, deploy and invoke a hub model using an explicit session."""
    model_id, model_version = "huggingface-spc-bert-base-cased", "*"

    sagemaker_session = get_sm_session()

    execution_role = sagemaker_session.get_caller_identity_arn()
    training_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    estimator = JumpStartEstimator(
        model_id=model_id,
        role=execution_role,
        sagemaker_session=sagemaker_session,
        tags=training_tags,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    # Training data lives in the JumpStart content bucket of the default region.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit(inputs={"training": f"s3://{content_bucket}/{dataset_key}"})

    # test that we can create a JumpStartEstimator from existing job with `attach`
    attached_estimator = JumpStartEstimator.attach(
        training_job_name=estimator.latest_training_job.name,
        model_id=model_id,
        model_version=model_version,
        sagemaker_session=get_sm_session(),
    )

    # uses ml.p3.2xlarge instance
    endpoint_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    predictor = attached_estimator.deploy(
        tags=endpoint_tags,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    response = predictor.predict(["hello", "world"])

    assert response is not None
108+
109+
110+
def test_jumpstart_hub_estimator_with_default_session(setup, add_model_references):
    """Train with an explicit session, then attach and deploy without passing one."""
    model_id, model_version = "huggingface-spc-bert-base-cased", "*"

    sagemaker_session = get_sm_session()

    execution_role = sagemaker_session.get_caller_identity_arn()
    training_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    estimator = JumpStartEstimator(
        model_id=model_id,
        role=execution_role,
        sagemaker_session=sagemaker_session,
        tags=training_tags,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    # Training data lives in the JumpStart content bucket of the default region.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit(inputs={"training": f"s3://{content_bucket}/{dataset_key}"})

    # test that we can create a JumpStartEstimator from existing job with `attach`
    # (no sagemaker_session is passed here on purpose)
    attached_estimator = JumpStartEstimator.attach(
        training_job_name=estimator.latest_training_job.name,
        model_id=model_id,
        model_version=model_version,
    )

    # uses ml.p3.2xlarge instance
    endpoint_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    predictor = attached_estimator.deploy(
        tags=endpoint_tags,
        role=get_sm_session().get_caller_identity_arn(),
    )

    response = predictor.predict(["hello", "world"])

    assert response is not None
147+
148+
149+
def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
    """Train a gated hub model with ``accept_eula=True``, then attach, deploy and invoke it."""
    model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

    execution_role = get_sm_session().get_caller_identity_arn()
    estimator = JumpStartEstimator(
        model_id=model_id,
        role=execution_role,
        sagemaker_session=get_sm_session(),
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    # Training data lives in the JumpStart content bucket of the default region.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit(
        accept_eula=True,
        inputs={"training": f"s3://{content_bucket}/{dataset_key}"},
    )

    attached_estimator = JumpStartEstimator.attach(
        training_job_name=estimator.latest_training_job.name,
        model_id=model_id,
        model_version=model_version,
        sagemaker_session=get_sm_session(),
    )

    # uses ml.p3.2xlarge instance
    endpoint_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    predictor = attached_estimator.deploy(
        tags=endpoint_tags,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    response = predictor.predict(["hello", "world"])

    assert response is not None
185+
186+
187+
def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):
    """Training a gated hub model without ``accept_eula`` must raise."""
    model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

    estimator = JumpStartEstimator(
        model_id=model_id,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    # Build the inputs before entering ``pytest.raises`` so that a failure in
    # the bucket/dataset helpers fails the test instead of being mistaken for
    # the expected error from ``fit``.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    training_inputs = {"training": f"s3://{content_bucket}/{dataset_key}"}

    with pytest.raises(Exception):
        estimator.fit(inputs=training_inputs)
204+
205+
206+
207+
def test_instantiating_estimator(setup, add_model_references):
    """Constructing a hub-backed estimator must finish within ``MAX_INIT_TIME_SECONDS``."""
    model_id = "catboost-regression-model"

    started_at = time.perf_counter()
    JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )
    finished_at = time.perf_counter()

    elapsed_time = finished_at - started_at
    assert elapsed_time <= MAX_INIT_TIME_SECONDS

0 commit comments

Comments
 (0)