Skip to content
Merged
33 changes: 27 additions & 6 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
JUMPSTART_LOGGER,
TRAINING_ENTRY_POINT_SCRIPT_NAME,
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
JUMPSTART_MODEL_HUB_NAME,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.jumpstart.factory import model
Expand Down Expand Up @@ -313,16 +314,31 @@ def _add_hub_access_config_to_kwargs_inputs(
):
"""Adds HubAccessConfig to kwargs inputs"""

dataset_uri = kwargs.specs.default_training_dataset_uri
if isinstance(kwargs.inputs, str):
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
if dataset_uri is not None and dataset_uri == kwargs.inputs:
kwargs.inputs = TrainingInput(
s3_data=kwargs.inputs, hub_access_config=hub_access_config
)
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
elif isinstance(kwargs.inputs, dict):
for k, v in kwargs.inputs.items():
if isinstance(v, str):
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
training_input = TrainingInput(s3_data=v)
if dataset_uri is not None and dataset_uri == v:
training_input.add_hub_access_config(hub_access_config=hub_access_config)
kwargs.inputs[k] = training_input
elif isinstance(kwargs.inputs, TrainingInput):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
if (
dataset_uri is not None
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
):
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)

return kwargs

Expand Down Expand Up @@ -616,8 +632,13 @@ def _add_model_reference_arn_to_kwargs(

def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
"""Sets model uri in kwargs based on default or override, returns full kwargs."""

if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
# hub_arn is None by default unless the user specifies a hub_name.
# If no hub_name is specified, the public hub is assumed.
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
if (
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
or is_private_hub
):
default_model_uri = model_uris.retrieve(
model_scope=JumpStartScriptScope.TRAINING,
instance_type=kwargs.instance_type,
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/jumpstart/hub/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if json_obj.get("ValidationSupported")
else None
)
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
Expand Down Expand Up @@ -671,6 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
)

if self.training_supported:
self.default_training_dataset_uri: Optional[str] = json_obj.get(
"DefaultTrainingDatasetUri"
)
self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
"TrainingModelPackageArtifactUri"
)
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
specs["training_instance_type_variants"] = (
hub_model_document.training_instance_type_variants
)
if hub_model_document.default_training_dataset_uri:
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.default_training_dataset_uri
)
specs["default_training_dataset_key"] = default_training_dataset_key
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)
33 changes: 29 additions & 4 deletions src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
from sagemaker.jumpstart import constants
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from packaging import version

PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"

Expand Down Expand Up @@ -219,9 +220,12 @@ def get_hub_model_version(
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

try:
hub_content_summaries = sagemaker_session.list_hub_content_versions(
hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type
).get("HubContentSummaries")
hub_content_summaries = _list_hub_content_versions_helper(
hub_name=hub_name,
hub_content_name=hub_model_name,
hub_content_type=hub_model_type,
sagemaker_session=sagemaker_session,
)
except Exception as ex:
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")

Expand All @@ -238,13 +242,34 @@ def get_hub_model_version(
raise


def _list_hub_content_versions_helper(
    hub_name, hub_content_name, hub_content_type, sagemaker_session
):
    """Return all hub content summaries for a hub content, following pagination.

    Calls ``sagemaker_session.list_hub_content_versions`` repeatedly. The first
    request carries only the hub identifiers; each later request also passes the
    previous response's ``NextToken`` as ``next_token``.

    Args:
        hub_name: Forwarded to ``list_hub_content_versions`` as ``hub_name``.
        hub_content_name: Forwarded as ``hub_content_name``.
        hub_content_type: Forwarded as ``hub_content_type``.
        sagemaker_session: Object exposing ``list_hub_content_versions``.

    Returns:
        list: The ``HubContentSummaries`` entries of every page, in response order.
    """
    all_hub_content_summaries = []
    request_kwargs = {
        "hub_name": hub_name,
        "hub_content_name": hub_content_name,
        "hub_content_type": hub_content_type,
    }
    while True:
        response = sagemaker_session.list_hub_content_versions(**request_kwargs)
        # A page may omit "HubContentSummaries" or carry None; treat that as an
        # empty page instead of letting list.extend(None) raise TypeError.
        all_hub_content_summaries.extend(response.get("HubContentSummaries") or [])

        # Stop on a missing *or empty* token so a blank NextToken cannot make the
        # loop re-issue the same request forever.
        next_token = response.get("NextToken")
        if not next_token:
            return all_hub_content_summaries
        request_kwargs["next_token"] = next_token


def _get_hub_model_version_for_open_weight_version(
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
) -> str:
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]

if hub_model_version == "*" or hub_model_version is None:
return str(max(available_model_versions))
return str(max(version.parse(v) for v in available_model_versions))

try:
spec = SpecifierSet(f"=={hub_model_version}")
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"hosting_neuron_model_version",
"hub_content_type",
"_is_hub_content",
"default_training_dataset_key",
"default_training_dataset_uri",
]

_non_serializable_slots = ["_is_hub_content"]
Expand Down Expand Up @@ -1462,6 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
else None
)
self.model_subscription_link = json_obj.get("model_subscription_link")
self.default_training_dataset_key: Optional[str] = json_obj.get(
"default_training_dataset_key"
)
self.default_training_dataset_uri: Optional[str] = json_obj.get(
"default_training_dataset_uri"
)

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataBaseFields object."""
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import time

import pytest
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket

from tests.integ.sagemaker.jumpstart.constants import (
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
get_public_hub_model_arn,
get_sm_session,
with_exponential_backoff,
get_training_dataset_for_model_and_version,
)

# Upper bound, in seconds, on how long constructing a JumpStartEstimator may
# take in test_instantiating_estimator.
MAX_INIT_TIME_SECONDS = 5

# Model IDs for which the add_model_references fixture creates model
# references in the test hub.
TEST_MODEL_IDS = {
"huggingface-spc-bert-base-cased",
"meta-textgeneration-llama-2-7b",
"catboost-regression-model",
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
    """Create a model reference in the given hub on a best-effort basis.

    Any exception raised by ``hub_instance.create_model_reference`` is caught so
    that callers can carry on, but it is logged instead of being discarded
    silently, so unexpected failures remain visible in the test output.

    Args:
        hub_instance: Object exposing ``create_model_reference(model_arn=...)``.
        model_arn: ARN of the model to reference.
    """
    import logging  # local import keeps the file's import block untouched

    try:
        hub_instance.create_model_reference(model_arn=model_arn)
    except Exception:  # pylint: disable=broad-except
        # Deliberately best-effort: do not fail the fixture, but leave a trace.
        logging.getLogger(__name__).warning(
            "Could not create model reference for %s; continuing.",
            model_arn,
            exc_info=True,
        )


@pytest.fixture(scope="session")
def add_model_references():
    """Session fixture: create a model reference in the test hub for each test model."""
    test_hub = Hub(
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        sagemaker_session=get_sm_session(),
    )
    for model_id in TEST_MODEL_IDS:
        public_model_arn = get_public_hub_model_arn(test_hub, model_id)
        create_model_reference(test_hub, public_model_arn)


def test_jumpstart_hub_estimator(setup, add_model_references):
    """Train, re-attach, deploy and invoke a hub model using default session settings."""
    model_id = "huggingface-spc-bert-base-cased"
    model_version = "*"

    def suite_tags():
        # Build a fresh list on every call so each SDK call gets its own tags object.
        return [{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]

    estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        tags=suite_tags(),
    )

    bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_prefix = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit(inputs={"training": f"s3://{bucket}/{dataset_prefix}"})

    # test that we can create a JumpStartEstimator from existing job with `attach`
    finished_job_name = estimator.latest_training_job.name
    estimator = JumpStartEstimator.attach(
        training_job_name=finished_job_name,
        model_id=model_id,
        model_version=model_version,
    )

    # uses ml.p3.2xlarge instance
    predictor = estimator.deploy(tags=suite_tags())

    assert predictor.predict(["hello", "world"]) is not None


def test_jumpstart_hub_estimator_with_session(setup, add_model_references):
    """Train, re-attach, deploy and invoke a hub model with explicit sessions and roles."""
    model_id = "huggingface-spc-bert-base-cased"
    model_version = "*"

    def suite_tags():
        # Build a fresh list on every call so each SDK call gets its own tags object.
        return [{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]

    sagemaker_session = get_sm_session()
    training_role = sagemaker_session.get_caller_identity_arn()

    estimator = JumpStartEstimator(
        model_id=model_id,
        role=training_role,
        sagemaker_session=sagemaker_session,
        tags=suite_tags(),
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_prefix = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit(inputs={"training": f"s3://{bucket}/{dataset_prefix}"})

    # test that we can create a JumpStartEstimator from existing job with `attach`
    finished_job_name = estimator.latest_training_job.name
    estimator = JumpStartEstimator.attach(
        training_job_name=finished_job_name,
        model_id=model_id,
        model_version=model_version,
        sagemaker_session=get_sm_session(),
    )

    # uses ml.p3.2xlarge instance
    deploy_tags = suite_tags()
    deploy_role = get_sm_session().get_caller_identity_arn()
    predictor = estimator.deploy(
        tags=deploy_tags,
        role=deploy_role,
        sagemaker_session=get_sm_session(),
    )

    assert predictor.predict(["hello", "world"]) is not None


def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
    """Train with ``accept_eula=True``, deploy, and invoke the Llama 2 7B hub model."""
    model_id = "meta-textgeneration-llama-2-7b"
    model_version = "*"

    def suite_tags():
        # Build a fresh list on every call so each SDK call gets its own tags object.
        return [{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]

    estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        tags=suite_tags(),
    )

    bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_prefix = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit(
        accept_eula=True,
        inputs={"training": f"s3://{bucket}/{dataset_prefix}"},
    )

    deploy_tags = suite_tags()
    deploy_role = get_sm_session().get_caller_identity_arn()
    predictor = estimator.deploy(
        tags=deploy_tags,
        role=deploy_role,
        sagemaker_session=get_sm_session(),
    )

    generation_parameters = {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}
    payload = {"inputs": "some-payload", "parameters": generation_parameters}

    response = predictor.predict(payload, custom_attributes="accept_eula=true")
    assert response is not None


def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):
    """Calling ``fit`` on the Llama 2 7B hub model without ``accept_eula`` must raise."""
    model_id = "meta-textgeneration-llama-2-7b"
    model_version = "*"

    suite_tag = {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        tags=[suite_tag],
    )

    with pytest.raises(Exception):
        # The dataset URI is built inside the `raises` block on purpose, so the
        # same set of statements stays covered by the expectation.
        bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
        dataset_prefix = get_training_dataset_for_model_and_version(model_id, model_version)
        estimator.fit(inputs={"training": f"s3://{bucket}/{dataset_prefix}"})


def test_instantiating_estimator(setup, add_model_references):
    """Constructing a hub-backed estimator must finish within MAX_INIT_TIME_SECONDS."""
    started_at = time.perf_counter()

    # Only construction is timed; the instance itself is not used afterwards.
    JumpStartEstimator(
        model_id="catboost-regression-model",
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    seconds_taken = time.perf_counter() - started_at
    assert seconds_taken <= MAX_INIT_TIME_SECONDS
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@

@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
hub_instance.create_model_reference(model_arn=model_arn)
try:
hub_instance.create_model_reference(model_arn=model_arn)
except Exception:
pass


@pytest.fixture(scope="session")
Expand Down
Loading