chore: remove support for ecr spec fallbacks for jumpstart models #4943

38 changes: 21 additions & 17 deletions src/sagemaker/image_uris.py
@@ -21,7 +21,7 @@
from packaging.version import Version

from sagemaker import utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
@@ -154,23 +154,27 @@ def retrieve(
)

if is_jumpstart_model_input(model_id, model_version):
if non_none_fields := {
key: value
for key, value in args.items()
if key in {"version", "framework", "container_version", "py_version"}
and value is not None
}:
JUMPSTART_LOGGER.info(
Contributor:
thanks - I would vote for a warning but your call :)

technically, this is still backward incompatible, but the risk of breaking any customers is really low (it would require cx to have written code that calls the retrieve method w/ both a model_id and a framework, container_version and/or py_version, which is unlikely).

Member (author):

i was thinking of a warning but thought it might spam the logs.

i don't see how we can make this not backwards incompatible; we can no longer rely on ecr specs for raising exceptions.

"Ignoring the following arguments when retrieving image uri "
"for JumpStart model id '%s': %s",
model_id,
str(non_none_fields),
)
return artifacts._retrieve_image_uri(
model_id,
model_version,
image_scope,
hub_arn,
framework,
region,
version,
py_version,
instance_type,
accelerator_type,
container_version,
distribution,
base_framework_version,
training_compiler_config,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
image_scope=image_scope,
hub_arn=hub_arn,
region=region,
instance_type=instance_type,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
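For context, a minimal sketch of how a caller retrieves a JumpStart image URI after this change, assuming the public `retrieve()` signature is otherwise unchanged. The framework-style arguments no longer influence resolution for JumpStart model ids; any non-None values are logged at info level and ignored. The model id below mirrors the one used in the integration tests further down, not a recommendation.

```python
# Minimal sketch (not from the PR): retrieving a JumpStart image URI after
# this change. framework / version / py_version / container_version are no
# longer used to resolve JumpStart images; non-None values are logged and
# ignored.
from sagemaker import image_uris

image_uri = image_uris.retrieve(
    framework=None,
    region="us-west-2",
    image_scope="inference",
    instance_type="ml.m5.xlarge",
    model_id="catboost-classification-model",  # example id from the integ tests
    model_version="*",
)
```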
106 changes: 10 additions & 96 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
@@ -14,14 +14,12 @@
from __future__ import absolute_import

from typing import Optional
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
ModelFramework,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
@@ -35,16 +33,8 @@ def _retrieve_image_uri(
model_version: str,
image_scope: str,
hub_arn: Optional[str] = None,
framework: Optional[str] = None,
region: Optional[str] = None,
version: Optional[str] = None,
py_version: Optional[str] = None,
instance_type: Optional[str] = None,
accelerator_type: Optional[str] = None,
container_version: Optional[str] = None,
distribution: Optional[str] = None,
base_framework_version: Optional[str] = None,
training_compiler_config: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -66,30 +56,11 @@
image_scope (str): The image type, i.e. what it is used for.
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
``image_scope`` is ignored.
framework (str): The name of the framework or algorithm.
region (str): The AWS region. (Default: None).
version (str): The framework or algorithm version. This is required if there is
more than one supported version for the given framework or algorithm.
(Default: None).
py_version (str): The Python version. This is required if there is
more than one supported Python version for the given framework version.
instance_type (str): The SageMaker instance type. For supported types, see
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
there are different images for different processor types.
(Default: None).
accelerator_type (str): Elastic Inference accelerator type. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
(Default: None).
container_version (str): the version of docker image.
Ideally the value of parameter should be created inside the framework.
For custom use, see the list of supported container versions:
https://github.com/aws/deep-learning-containers/blob/master/available_images.md.
(Default: None).
distribution (dict): A dictionary with information on how to run distributed training.
(Default: None).
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
@@ -142,14 +113,12 @@ def _retrieve_image_uri(
ecr_uri = model_specs.hosting_ecr_uri
return ecr_uri

ecr_specs = model_specs.hosting_ecr_specs
if ecr_specs is None:
raise ValueError(
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
f"with {instance_type} instance type in {region}. "
"Please try another instance type or region."
)
elif image_scope == JumpStartScriptScope.TRAINING:
raise ValueError(
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
f"with {instance_type} instance type in {region}. "
"Please try another instance type or region."
)
if image_scope == JumpStartScriptScope.TRAINING:
training_instance_type_variants = model_specs.training_instance_type_variants
if training_instance_type_variants:
image_uri = training_instance_type_variants.get_image_uri(
@@ -161,65 +130,10 @@
ecr_uri = model_specs.training_ecr_uri
return ecr_uri

ecr_specs = model_specs.training_ecr_specs
if ecr_specs is None:
raise ValueError(
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
f"with {instance_type} instance type in {region}. "
"Please try another instance type or region."
)
if framework is not None and framework != ecr_specs.framework:
raise ValueError(
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
f"and version '{model_version}'."
)

if version is not None and version != ecr_specs.framework_version:
raise ValueError(
f"Incorrect container framework version '{version}' for JumpStart model ID "
f"'{model_id}' and version '{model_version}'."
)

if py_version is not None and py_version != ecr_specs.py_version:
raise ValueError(
f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' "
f"and version '{model_version}'."
)

base_framework_version_override: Optional[str] = None
version_override: Optional[str] = None
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
base_framework_version_override = ecr_specs.framework_version
version_override = ecr_specs.huggingface_transformers_version

if image_scope == JumpStartScriptScope.TRAINING:
return image_uris.get_training_image_uri(
region=region,
framework=ecr_specs.framework,
framework_version=version_override or ecr_specs.framework_version,
py_version=ecr_specs.py_version,
image_uri=None,
distribution=None,
compiler_config=None,
tensorflow_version=None,
pytorch_version=base_framework_version_override or base_framework_version,
instance_type=instance_type,
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
f"with {instance_type} instance type in {region}. "
"Please try another instance type or region."
)
if base_framework_version_override is not None:
base_framework_version_override = f"pytorch{base_framework_version_override}"

return image_uris.retrieve(
framework=ecr_specs.framework,
region=region,
version=version_override or ecr_specs.framework_version,
py_version=ecr_specs.py_version,
instance_type=instance_type,
hub_arn=hub_arn,
accelerator_type=accelerator_type,
image_scope=image_scope,
container_version=container_version,
distribution=distribution,
base_framework_version=base_framework_version_override or base_framework_version,
training_compiler_config=training_compiler_config,
config_name=config_name,
)
raise ValueError(f"Invalid scope: {image_scope}")
1 change: 1 addition & 0 deletions tests/integ/sagemaker/jumpstart/constants.py
@@ -46,6 +46,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
TRAINING_DATASET_MODEL_DICT = {
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
26 changes: 20 additions & 6 deletions tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py
@@ -77,6 +77,10 @@ def package_artifacts(self):

self.model_name = self.get_model_name()

if self.script_uri is None:
print("No script uri provided. Not performing prepack")
return self.model_uri

cache_bucket_uri = f"s3://{get_test_artifact_bucket()}"
repacked_model_uri = "/".join(
[
@@ -147,16 +151,26 @@ def get_model_name(self) -> str:
return f"{non_timestamped_name}{self.suffix}"

def create_model(self) -> None:
primary_container = {
"Image": self.image_uri,
"Mode": "SingleModel",
"Environment": self.environment_variables,
}
if self.repacked_model_uri.endswith(".tar.gz"):
primary_container["ModelDataUrl"] = self.repacked_model_uri
else:
primary_container["ModelDataSource"] = {
"S3DataSource": {
"S3Uri": self.repacked_model_uri,
"S3DataType": "S3Prefix",
"CompressionType": "None",
}
}
self.sagemaker_client.create_model(
ModelName=self.model_name,
EnableNetworkIsolation=True,
ExecutionRoleArn=self.execution_role,
PrimaryContainer={
"Image": self.image_uri,
"ModelDataUrl": self.repacked_model_uri,
"Mode": "SingleModel",
"Environment": self.environment_variables,
},
PrimaryContainer=primary_container,
)

def create_endpoint_config(self) -> None:
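The integration helper above now picks the container's model-data field from the artifact URI: a `.tar.gz` artifact keeps the original `ModelDataUrl` shape, while an uncompressed S3 prefix is passed as `ModelDataSource`. A small sketch of that selection; the helper name and its standalone form are illustrative, not part of the change.

```python
# Sketch of the PrimaryContainer selection shown in the diff above.
def build_primary_container(image_uri, model_data_uri, environment_variables):
    container = {
        "Image": image_uri,
        "Mode": "SingleModel",
        "Environment": environment_variables,
    }
    if model_data_uri.endswith(".tar.gz"):
        # Compressed artifact: point directly at the tarball.
        container["ModelDataUrl"] = model_data_uri
    else:
        # Uncompressed artifacts: reference the S3 prefix instead.
        container["ModelDataSource"] = {
            "S3DataSource": {
                "S3Uri": model_data_uri,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
            }
        }
    return container
```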
14 changes: 3 additions & 11 deletions tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py
@@ -17,7 +17,6 @@
InferenceJobLauncher,
)
from sagemaker import environment_variables, image_uris
from sagemaker import script_uris
from sagemaker import model_uris

from tests.integ.sagemaker.jumpstart.constants import InferenceTabularDataname
@@ -31,8 +30,8 @@

def test_jumpstart_inference_retrieve_functions(setup):

model_id, model_version = "catboost-classification-model", "1.0.0"
instance_type = "ml.m5.xlarge"
model_id, model_version = "catboost-classification-model", "2.1.6"
instance_type = "ml.m5.4xlarge"

print("Starting inference...")

@@ -46,13 +45,6 @@ def test_jumpstart_inference_retrieve_functions(setup):
tolerate_vulnerable_model=True,
)

script_uri = script_uris.retrieve(
model_id=model_id,
model_version=model_version,
script_scope="inference",
tolerate_vulnerable_model=True,
)

model_uri = model_uris.retrieve(
model_id=model_id,
model_version=model_version,
@@ -68,7 +60,7 @@

inference_job = InferenceJobLauncher(
image_uri=image_uri,
script_uri=script_uri,
script_uri=None,
model_uri=model_uri,
instance_type=instance_type,
base_name="catboost",
@@ -33,7 +33,7 @@

def test_jumpstart_transfer_learning_retrieve_functions(setup):

model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0"
model_id, model_version = "huggingface-spc-bert-base-cased", "2.0.3"
training_instance_type = "ml.p3.2xlarge"
inference_instance_type = "ml.p2.xlarge"

28 changes: 21 additions & 7 deletions tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py
@@ -46,7 +46,13 @@ def test_jumpstart_default_hyperparameters(
model_version="*",
sagemaker_session=mock_session,
)
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
assert params == {
"train_only_top_layer": "True",
"epochs": "5",
"learning_rate": "0.001",
"batch_size": "4",
"reinitialize_top_layer": "Auto",
}

patched_get_model_specs.assert_called_once_with(
region=region,
@@ -66,7 +72,13 @@ def test_jumpstart_default_hyperparameters(
model_version="1.*",
sagemaker_session=mock_session,
)
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
assert params == {
"train_only_top_layer": "True",
"epochs": "5",
"learning_rate": "0.001",
"batch_size": "4",
"reinitialize_top_layer": "Auto",
}

patched_get_model_specs.assert_called_once_with(
region=region,
@@ -88,12 +100,14 @@
sagemaker_session=mock_session,
)
assert params == {
"adam-learning-rate": "0.05",
"batch-size": "4",
"epochs": "3",
"sagemaker_container_log_level": "20",
"sagemaker_program": "transfer_learning.py",
"train_only_top_layer": "True",
"epochs": "5",
"learning_rate": "0.001",
"batch_size": "4",
"reinitialize_top_layer": "Auto",
"sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
"sagemaker_program": "transfer_learning.py",
"sagemaker_container_log_level": "20",
}

patched_get_model_specs.assert_called_once_with(
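The updated assertions simply track the default hyperparameters shipped with the newer spec. For reference, a hedged sketch of the call path this unit test exercises; the model id and region are placeholders, since the test patches the spec lookup rather than hitting a real model.

```python
# Sketch only: the unit test mocks get_model_specs, so the id/region below
# are placeholders rather than values taken from this PR.
from sagemaker import hyperparameters

params = hyperparameters.retrieve_default(
    region="us-west-2",
    model_id="pytorch-ic-mobilenet-v2",  # hypothetical JumpStart model id
    model_version="*",
)
# Passing include_container_hyperparameters=True additionally returns the
# sagemaker_program / sagemaker_submit_directory / sagemaker_container_log_level
# entries asserted in the last block above.
```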