Skip to content

Commit 1326608

Browse files
authored
chore: remove support for ecr spec fallbacks for jumpstart models (#4943)
* chore: remove support for ecr spec fallbacks for jumpstart models
* fix: formatting issues
* fix: integ tests
* fix: integ tests
* fix: flake8
* chore: emit log when legacy fields are used to get jumpstart image uri
* fix: typo
1 parent 048d1f1 commit 1326608

File tree

29 files changed

+10637
-2875
lines changed

29 files changed

+10637
-2875
lines changed

src/sagemaker/image_uris.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from packaging.version import Version
2222

2323
from sagemaker import utils
24-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
24+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
2525
from sagemaker.jumpstart.enums import JumpStartModelType
2626
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2727
from sagemaker.spark import defaults
@@ -154,23 +154,27 @@ def retrieve(
154154
)
155155

156156
if is_jumpstart_model_input(model_id, model_version):
157+
if non_none_fields := {
158+
key: value
159+
for key, value in args.items()
160+
if key in {"version", "framework", "container_version", "py_version"}
161+
and value is not None
162+
}:
163+
JUMPSTART_LOGGER.info(
164+
"Ignoring the following arguments when retrieving image uri "
165+
"for JumpStart model id '%s': %s",
166+
model_id,
167+
str(non_none_fields),
168+
)
157169
return artifacts._retrieve_image_uri(
158-
model_id,
159-
model_version,
160-
image_scope,
161-
hub_arn,
162-
framework,
163-
region,
164-
version,
165-
py_version,
166-
instance_type,
167-
accelerator_type,
168-
container_version,
169-
distribution,
170-
base_framework_version,
171-
training_compiler_config,
172-
tolerate_vulnerable_model,
173-
tolerate_deprecated_model,
170+
model_id=model_id,
171+
model_version=model_version,
172+
image_scope=image_scope,
173+
hub_arn=hub_arn,
174+
region=region,
175+
instance_type=instance_type,
176+
tolerate_vulnerable_model=tolerate_vulnerable_model,
177+
tolerate_deprecated_model=tolerate_deprecated_model,
174178
sagemaker_session=sagemaker_session,
175179
config_name=config_name,
176180
model_type=model_type,

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 10 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
from typing import Optional
17-
from sagemaker import image_uris
1817
from sagemaker.jumpstart.constants import (
1918
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2019
)
2120
from sagemaker.jumpstart.enums import (
2221
JumpStartModelType,
2322
JumpStartScriptScope,
24-
ModelFramework,
2523
)
2624
from sagemaker.jumpstart.utils import (
2725
get_region_fallback,
@@ -35,16 +33,8 @@ def _retrieve_image_uri(
3533
model_version: str,
3634
image_scope: str,
3735
hub_arn: Optional[str] = None,
38-
framework: Optional[str] = None,
3936
region: Optional[str] = None,
40-
version: Optional[str] = None,
41-
py_version: Optional[str] = None,
4237
instance_type: Optional[str] = None,
43-
accelerator_type: Optional[str] = None,
44-
container_version: Optional[str] = None,
45-
distribution: Optional[str] = None,
46-
base_framework_version: Optional[str] = None,
47-
training_compiler_config: Optional[str] = None,
4838
tolerate_vulnerable_model: bool = False,
4939
tolerate_deprecated_model: bool = False,
5040
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -66,30 +56,11 @@ def _retrieve_image_uri(
6656
image_scope (str): The image type, i.e. what it is used for.
6757
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
6858
``image_scope`` is ignored.
69-
framework (str): The name of the framework or algorithm.
7059
region (str): The AWS region. (Default: None).
71-
version (str): The framework or algorithm version. This is required if there is
72-
more than one supported version for the given framework or algorithm.
73-
(Default: None).
74-
py_version (str): The Python version. This is required if there is
75-
more than one supported Python version for the given framework version.
7660
instance_type (str): The SageMaker instance type. For supported types, see
7761
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
7862
there are different images for different processor types.
7963
(Default: None).
80-
accelerator_type (str): Elastic Inference accelerator type. For more, see
81-
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
82-
(Default: None).
83-
container_version (str): the version of docker image.
84-
Ideally the value of parameter should be created inside the framework.
85-
For custom use, see the list of supported container versions:
86-
https://github.com/aws/deep-learning-containers/blob/master/available_images.md.
87-
(Default: None).
88-
distribution (dict): A dictionary with information on how to run distributed training.
89-
(Default: None).
90-
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
91-
A configuration class for the SageMaker Training Compiler.
92-
(Default: None).
9364
tolerate_vulnerable_model (bool): True if vulnerable versions of model
9465
specifications should be tolerated (exception not raised). If False, raises an
9566
exception if the script used by this version of the model has dependencies with known
@@ -142,14 +113,12 @@ def _retrieve_image_uri(
142113
ecr_uri = model_specs.hosting_ecr_uri
143114
return ecr_uri
144115

145-
ecr_specs = model_specs.hosting_ecr_specs
146-
if ecr_specs is None:
147-
raise ValueError(
148-
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
149-
f"with {instance_type} instance type in {region}. "
150-
"Please try another instance type or region."
151-
)
152-
elif image_scope == JumpStartScriptScope.TRAINING:
116+
raise ValueError(
117+
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
118+
f"with {instance_type} instance type in {region}. "
119+
"Please try another instance type or region."
120+
)
121+
if image_scope == JumpStartScriptScope.TRAINING:
153122
training_instance_type_variants = model_specs.training_instance_type_variants
154123
if training_instance_type_variants:
155124
image_uri = training_instance_type_variants.get_image_uri(
@@ -161,65 +130,10 @@ def _retrieve_image_uri(
161130
ecr_uri = model_specs.training_ecr_uri
162131
return ecr_uri
163132

164-
ecr_specs = model_specs.training_ecr_specs
165-
if ecr_specs is None:
166-
raise ValueError(
167-
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
168-
f"with {instance_type} instance type in {region}. "
169-
"Please try another instance type or region."
170-
)
171-
if framework is not None and framework != ecr_specs.framework:
172-
raise ValueError(
173-
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
174-
f"and version '{model_version}'."
175-
)
176-
177-
if version is not None and version != ecr_specs.framework_version:
178-
raise ValueError(
179-
f"Incorrect container framework version '{version}' for JumpStart model ID "
180-
f"'{model_id}' and version '{model_version}'."
181-
)
182-
183-
if py_version is not None and py_version != ecr_specs.py_version:
184133
raise ValueError(
185-
f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' "
186-
f"and version '{model_version}'."
187-
)
188-
189-
base_framework_version_override: Optional[str] = None
190-
version_override: Optional[str] = None
191-
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
192-
base_framework_version_override = ecr_specs.framework_version
193-
version_override = ecr_specs.huggingface_transformers_version
194-
195-
if image_scope == JumpStartScriptScope.TRAINING:
196-
return image_uris.get_training_image_uri(
197-
region=region,
198-
framework=ecr_specs.framework,
199-
framework_version=version_override or ecr_specs.framework_version,
200-
py_version=ecr_specs.py_version,
201-
image_uri=None,
202-
distribution=None,
203-
compiler_config=None,
204-
tensorflow_version=None,
205-
pytorch_version=base_framework_version_override or base_framework_version,
206-
instance_type=instance_type,
134+
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
135+
f"with {instance_type} instance type in {region}. "
136+
"Please try another instance type or region."
207137
)
208-
if base_framework_version_override is not None:
209-
base_framework_version_override = f"pytorch{base_framework_version_override}"
210138

211-
return image_uris.retrieve(
212-
framework=ecr_specs.framework,
213-
region=region,
214-
version=version_override or ecr_specs.framework_version,
215-
py_version=ecr_specs.py_version,
216-
instance_type=instance_type,
217-
hub_arn=hub_arn,
218-
accelerator_type=accelerator_type,
219-
image_scope=image_scope,
220-
container_version=container_version,
221-
distribution=distribution,
222-
base_framework_version=base_framework_version_override or base_framework_version,
223-
training_compiler_config=training_compiler_config,
224-
config_name=config_name,
225-
)
139+
raise ValueError(f"Invalid scope: {image_scope}")

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4646
TRAINING_DATASET_MODEL_DICT = {
4747
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
4848
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
49+
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
4950
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
5051
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
5152
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),

tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def package_artifacts(self):
7777

7878
self.model_name = self.get_model_name()
7979

80+
if self.script_uri is None:
81+
print("No script uri provided. Not performing prepack")
82+
return self.model_uri
83+
8084
cache_bucket_uri = f"s3://{get_test_artifact_bucket()}"
8185
repacked_model_uri = "/".join(
8286
[
@@ -147,16 +151,26 @@ def get_model_name(self) -> str:
147151
return f"{non_timestamped_name}{self.suffix}"
148152

149153
def create_model(self) -> None:
154+
primary_container = {
155+
"Image": self.image_uri,
156+
"Mode": "SingleModel",
157+
"Environment": self.environment_variables,
158+
}
159+
if self.repacked_model_uri.endswith(".tar.gz"):
160+
primary_container["ModelDataUrl"] = self.repacked_model_uri
161+
else:
162+
primary_container["ModelDataSource"] = {
163+
"S3DataSource": {
164+
"S3Uri": self.repacked_model_uri,
165+
"S3DataType": "S3Prefix",
166+
"CompressionType": "None",
167+
}
168+
}
150169
self.sagemaker_client.create_model(
151170
ModelName=self.model_name,
152171
EnableNetworkIsolation=True,
153172
ExecutionRoleArn=self.execution_role,
154-
PrimaryContainer={
155-
"Image": self.image_uri,
156-
"ModelDataUrl": self.repacked_model_uri,
157-
"Mode": "SingleModel",
158-
"Environment": self.environment_variables,
159-
},
173+
PrimaryContainer=primary_container,
160174
)
161175

162176
def create_endpoint_config(self) -> None:

tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
InferenceJobLauncher,
1818
)
1919
from sagemaker import environment_variables, image_uris
20-
from sagemaker import script_uris
2120
from sagemaker import model_uris
2221

2322
from tests.integ.sagemaker.jumpstart.constants import InferenceTabularDataname
@@ -31,8 +30,8 @@
3130

3231
def test_jumpstart_inference_retrieve_functions(setup):
3332

34-
model_id, model_version = "catboost-classification-model", "1.0.0"
35-
instance_type = "ml.m5.xlarge"
33+
model_id, model_version = "catboost-classification-model", "2.1.6"
34+
instance_type = "ml.m5.4xlarge"
3635

3736
print("Starting inference...")
3837

@@ -46,13 +45,6 @@ def test_jumpstart_inference_retrieve_functions(setup):
4645
tolerate_vulnerable_model=True,
4746
)
4847

49-
script_uri = script_uris.retrieve(
50-
model_id=model_id,
51-
model_version=model_version,
52-
script_scope="inference",
53-
tolerate_vulnerable_model=True,
54-
)
55-
5648
model_uri = model_uris.retrieve(
5749
model_id=model_id,
5850
model_version=model_version,
@@ -68,7 +60,7 @@ def test_jumpstart_inference_retrieve_functions(setup):
6860

6961
inference_job = InferenceJobLauncher(
7062
image_uri=image_uri,
71-
script_uri=script_uri,
63+
script_uri=None,
7264
model_uri=model_uri,
7365
instance_type=instance_type,
7466
base_name="catboost",

tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def test_jumpstart_transfer_learning_retrieve_functions(setup):
3535

36-
model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0"
36+
model_id, model_version = "huggingface-spc-bert-base-cased", "2.0.3"
3737
training_instance_type = "ml.p3.2xlarge"
3838
inference_instance_type = "ml.p2.xlarge"
3939

tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ def test_jumpstart_default_hyperparameters(
4646
model_version="*",
4747
sagemaker_session=mock_session,
4848
)
49-
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
49+
assert params == {
50+
"train_only_top_layer": "True",
51+
"epochs": "5",
52+
"learning_rate": "0.001",
53+
"batch_size": "4",
54+
"reinitialize_top_layer": "Auto",
55+
}
5056

5157
patched_get_model_specs.assert_called_once_with(
5258
region=region,
@@ -66,7 +72,13 @@ def test_jumpstart_default_hyperparameters(
6672
model_version="1.*",
6773
sagemaker_session=mock_session,
6874
)
69-
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
75+
assert params == {
76+
"train_only_top_layer": "True",
77+
"epochs": "5",
78+
"learning_rate": "0.001",
79+
"batch_size": "4",
80+
"reinitialize_top_layer": "Auto",
81+
}
7082

7183
patched_get_model_specs.assert_called_once_with(
7284
region=region,
@@ -88,12 +100,14 @@ def test_jumpstart_default_hyperparameters(
88100
sagemaker_session=mock_session,
89101
)
90102
assert params == {
91-
"adam-learning-rate": "0.05",
92-
"batch-size": "4",
93-
"epochs": "3",
94-
"sagemaker_container_log_level": "20",
95-
"sagemaker_program": "transfer_learning.py",
103+
"train_only_top_layer": "True",
104+
"epochs": "5",
105+
"learning_rate": "0.001",
106+
"batch_size": "4",
107+
"reinitialize_top_layer": "Auto",
96108
"sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
109+
"sagemaker_program": "transfer_learning.py",
110+
"sagemaker_container_log_level": "20",
97111
}
98112

99113
patched_get_model_specs.assert_called_once_with(

0 commit comments

Comments
 (0)