Skip to content
Merged
24 changes: 19 additions & 5 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,31 @@ def _add_hub_access_config_to_kwargs_inputs(
kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None
):
"""Adds HubAccessConfig to kwargs inputs"""

dataset_uri = kwargs.specs.default_training_dataset_uri
if isinstance(kwargs.inputs, str):
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
if dataset_uri is not None and dataset_uri == kwargs.inputs:
kwargs.inputs = TrainingInput(
s3_data=kwargs.inputs, hub_access_config=hub_access_config
)
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
elif isinstance(kwargs.inputs, dict):
for k, v in kwargs.inputs.items():
if isinstance(v, str):
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
training_input = TrainingInput(s3_data=v)
if dataset_uri is not None and dataset_uri == v:
training_input.add_hub_access_config(hub_access_config=hub_access_config)
kwargs.inputs[k] = training_input
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)

return kwargs

Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
specs["training_instance_type_variants"] = (
hub_model_document.training_instance_type_variants
)
if hub_model_document.default_training_dataset_uri:
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.default_training_dataset_uri
)
specs["default_training_dataset_key"] = default_training_dataset_key
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"hosting_neuron_model_version",
"hub_content_type",
"_is_hub_content",
"default_training_dataset_key",
"default_training_dataset_uri",
]

_non_serializable_slots = ["_is_hub_content"]
Expand Down Expand Up @@ -1462,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
else None
)
self.model_subscription_link = json_obj.get("model_subscription_link")
self.default_training_dataset_key: Optional[str] = json_obj.get(
"default_training_dataset_key"
)
self.default_training_dataset_uri: Optional[str] = json_obj.get(
"default_training_dataset_uri"
)

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataBaseFields object."""
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import time

import pytest
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket

from tests.integ.sagemaker.jumpstart.constants import (
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
get_public_hub_model_arn,
get_sm_session,
with_exponential_backoff,
get_training_dataset_for_model_and_version,
)

MAX_INIT_TIME_SECONDS = 5

TEST_MODEL_IDS = {
"huggingface-spc-bert-base-cased",
"meta-textgeneration-llama-2-7b",
"catboost-regression-model",
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
hub_instance.create_model_reference(model_arn=model_arn)


@pytest.fixture(scope="session")
def add_model_references():
# Create Model References to test in Hub
hub_instance = Hub(
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
)
for model in TEST_MODEL_IDS:
model_arn = get_public_hub_model_arn(hub_instance, model)
create_model_reference(hub_instance, model_arn)


def test_jumpstart_hub_estimator(setup, add_model_references):

model_id, model_version = "huggingface-spc-bert-base-cased", "*"

sagemaker_session = get_sm_session()

estimator = JumpStartEstimator(
model_id=model_id,
role=sagemaker_session.get_caller_identity_arn(),
sagemaker_session=sagemaker_session,
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_estimator_with_default_session(setup, add_model_references):
model_id, model_version = "huggingface-spc-bert-base-cased", "*"

sagemaker_session = get_sm_session()

estimator = JumpStartEstimator(
model_id=model_id,
role=sagemaker_session.get_caller_identity_arn(),
sagemaker_session=sagemaker_session,
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):

model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

estimator = JumpStartEstimator(
model_id=model_id,
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

estimator.fit(
accept_eula=True,
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
},
)

estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.p3.2xlarge instance
predictor = estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)

response = predictor.predict(["hello", "world"])

assert response is not None


def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):

model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

estimator = JumpStartEstimator(
model_id=model_id,
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)
with pytest.raises(Exception):
estimator.fit(
inputs={
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
}
)


def test_instantiating_estimator(setup, add_model_references):

model_id = "catboost-regression-model"

start_time = time.perf_counter()

JumpStartEstimator(
model_id=model_id,
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)

elapsed_time = time.perf_counter() - start_time

assert elapsed_time <= MAX_INIT_TIME_SECONDS
Loading