Skip to content

Commit 3d6e884

Browse files
committed
chore: remove support for ecr spec fallbacks for jumpstart models
1 parent 6333914 commit 3d6e884

File tree

24 files changed

+10512
-2814
lines changed

24 files changed

+10512
-2814
lines changed

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 8 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
from typing import Optional
17-
from sagemaker import image_uris
1817
from sagemaker.jumpstart.constants import (
1918
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2019
)
2120
from sagemaker.jumpstart.enums import (
2221
JumpStartModelType,
2322
JumpStartScriptScope,
24-
ModelFramework,
2523
)
2624
from sagemaker.jumpstart.utils import (
2725
get_region_fallback,
@@ -142,13 +140,11 @@ def _retrieve_image_uri(
142140
ecr_uri = model_specs.hosting_ecr_uri
143141
return ecr_uri
144142

145-
ecr_specs = model_specs.hosting_ecr_specs
146-
if ecr_specs is None:
147-
raise ValueError(
148-
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
149-
f"with {instance_type} instance type in {region}. "
150-
"Please try another instance type or region."
151-
)
143+
raise ValueError(
144+
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
145+
f"with {instance_type} instance type in {region}. "
146+
"Please try another instance type or region."
147+
)
152148
elif image_scope == JumpStartScriptScope.TRAINING:
153149
training_instance_type_variants = model_specs.training_instance_type_variants
154150
if training_instance_type_variants:
@@ -161,65 +157,8 @@ def _retrieve_image_uri(
161157
ecr_uri = model_specs.training_ecr_uri
162158
return ecr_uri
163159

164-
ecr_specs = model_specs.training_ecr_specs
165-
if ecr_specs is None:
166-
raise ValueError(
167-
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
168-
f"with {instance_type} instance type in {region}. "
169-
"Please try another instance type or region."
170-
)
171-
if framework is not None and framework != ecr_specs.framework:
172-
raise ValueError(
173-
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
174-
f"and version '{model_version}'."
175-
)
176-
177-
if version is not None and version != ecr_specs.framework_version:
178-
raise ValueError(
179-
f"Incorrect container framework version '{version}' for JumpStart model ID "
180-
f"'{model_id}' and version '{model_version}'."
181-
)
182-
183-
if py_version is not None and py_version != ecr_specs.py_version:
184160
raise ValueError(
185-
f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' "
186-
f"and version '{model_version}'."
161+
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
162+
f"with {instance_type} instance type in {region}. "
163+
"Please try another instance type or region."
187164
)
188-
189-
base_framework_version_override: Optional[str] = None
190-
version_override: Optional[str] = None
191-
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
192-
base_framework_version_override = ecr_specs.framework_version
193-
version_override = ecr_specs.huggingface_transformers_version
194-
195-
if image_scope == JumpStartScriptScope.TRAINING:
196-
return image_uris.get_training_image_uri(
197-
region=region,
198-
framework=ecr_specs.framework,
199-
framework_version=version_override or ecr_specs.framework_version,
200-
py_version=ecr_specs.py_version,
201-
image_uri=None,
202-
distribution=None,
203-
compiler_config=None,
204-
tensorflow_version=None,
205-
pytorch_version=base_framework_version_override or base_framework_version,
206-
instance_type=instance_type,
207-
)
208-
if base_framework_version_override is not None:
209-
base_framework_version_override = f"pytorch{base_framework_version_override}"
210-
211-
return image_uris.retrieve(
212-
framework=ecr_specs.framework,
213-
region=region,
214-
version=version_override or ecr_specs.framework_version,
215-
py_version=ecr_specs.py_version,
216-
instance_type=instance_type,
217-
hub_arn=hub_arn,
218-
accelerator_type=accelerator_type,
219-
image_scope=image_scope,
220-
container_version=container_version,
221-
distribution=distribution,
222-
base_framework_version=base_framework_version_override or base_framework_version,
223-
training_compiler_config=training_compiler_config,
224-
config_name=config_name,
225-
)

tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ def test_jumpstart_default_hyperparameters(
4646
model_version="*",
4747
sagemaker_session=mock_session,
4848
)
49-
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
49+
assert params == {
50+
"train_only_top_layer": "True",
51+
"epochs": "5",
52+
"learning_rate": "0.001",
53+
"batch_size": "4",
54+
"reinitialize_top_layer": "Auto",
55+
}
5056

5157
patched_get_model_specs.assert_called_once_with(
5258
region=region,
@@ -66,7 +72,13 @@ def test_jumpstart_default_hyperparameters(
6672
model_version="1.*",
6773
sagemaker_session=mock_session,
6874
)
69-
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
75+
assert params == {
76+
"train_only_top_layer": "True",
77+
"epochs": "5",
78+
"learning_rate": "0.001",
79+
"batch_size": "4",
80+
"reinitialize_top_layer": "Auto",
81+
}
7082

7183
patched_get_model_specs.assert_called_once_with(
7284
region=region,
@@ -88,12 +100,14 @@ def test_jumpstart_default_hyperparameters(
88100
sagemaker_session=mock_session,
89101
)
90102
assert params == {
91-
"adam-learning-rate": "0.05",
92-
"batch-size": "4",
93-
"epochs": "3",
94-
"sagemaker_container_log_level": "20",
95-
"sagemaker_program": "transfer_learning.py",
103+
"train_only_top_layer": "True",
104+
"epochs": "5",
105+
"learning_rate": "0.001",
106+
"batch_size": "4",
107+
"reinitialize_top_layer": "Auto",
96108
"sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
109+
"sagemaker_program": "transfer_learning.py",
110+
"sagemaker_container_log_level": "20",
97111
}
98112

99113
patched_get_model_specs.assert_called_once_with(

tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError
2222
from sagemaker.jumpstart.types import JumpStartHyperparameter
2323

24-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
24+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec, get_spec_from_base_spec
2525

2626
region = "us-west-2"
2727
mock_client = boto3.client("s3")
@@ -34,7 +34,7 @@ def test_jumpstart_validate_provided_hyperparameters(
3434
patched_get_model_specs, patched_validate_model_id_and_get_type
3535
):
3636
def add_options_to_hyperparameter(*largs, **kwargs):
37-
spec = get_spec_from_base_spec(*largs, **kwargs)
37+
spec = get_prototype_model_spec(*largs, **kwargs)
3838
spec.hyperparameters.extend(
3939
[
4040
JumpStartHyperparameter(
@@ -115,7 +115,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
115115
patched_get_model_specs.side_effect = add_options_to_hyperparameter
116116
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
117117

118-
model_id, model_version = "pytorch-eqa-bert-base-cased", "*"
118+
model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*"
119119
region = "us-west-2"
120120

121121
hyperparameter_to_test = {
@@ -412,7 +412,7 @@ def test_jumpstart_validate_algorithm_hyperparameters(
412412
patched_get_model_specs, patched_validate_model_id_and_get_type
413413
):
414414
def add_options_to_hyperparameter(*largs, **kwargs):
415-
spec = get_spec_from_base_spec(*largs, **kwargs)
415+
spec = get_prototype_model_spec(*largs, **kwargs)
416416
spec.hyperparameters.append(
417417
JumpStartHyperparameter(
418418
{
@@ -429,10 +429,11 @@ def add_options_to_hyperparameter(*largs, **kwargs):
429429
patched_get_model_specs.side_effect = add_options_to_hyperparameter
430430
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
431431

432-
model_id, model_version = "pytorch-eqa-bert-base-cased", "*"
432+
model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*"
433433
region = "us-west-2"
434434

435435
hyperparameter_to_test = {
436+
"train-only-top-layer": "True",
436437
"adam-learning-rate": "0.05",
437438
"batch-size": "4",
438439
"epochs": "3",
@@ -488,13 +489,14 @@ def test_jumpstart_validate_all_hyperparameters(
488489
patched_get_model_specs, patched_validate_model_id_and_get_type
489490
):
490491

491-
patched_get_model_specs.side_effect = get_spec_from_base_spec
492+
patched_get_model_specs.side_effect = get_prototype_model_spec
492493
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
493494

494-
model_id, model_version = "pytorch-eqa-bert-base-cased", "*"
495+
model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*"
495496
region = "us-west-2"
496497

497498
hyperparameter_to_test = {
499+
"train-only-top-layer": "True",
498500
"adam-learning-rate": "0.05",
499501
"batch-size": "4",
500502
"epochs": "3",

tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs, session):
3030
patched_get_model_specs.side_effect = get_prototype_model_spec
3131

3232
model_id, model_version = "catboost-classification-model", "*"
33-
instance_type = "ml.p2.xlarge"
33+
instance_type = "ml.m5.xlarge"
3434
region = "us-west-2"
3535

3636
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
@@ -55,7 +55,7 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs, session):
5555
).serving_image_uri(region, instance_type)
5656

5757
assert uri == framework_class_uri
58-
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
58+
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310"
5959

6060
# training
6161
uri = image_uris.retrieve(
@@ -78,4 +78,4 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs, session):
7878
).training_image_uri(region=region)
7979

8080
assert uri == framework_class_uri
81-
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38"
81+
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38"

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_jumpstart_common_image_uri(
4747
image_scope="training",
4848
model_id="pytorch-ic-mobilenet-v2",
4949
model_version="*",
50-
instance_type="ml.p2.xlarge",
50+
instance_type="ml.m5.xlarge",
5151
sagemaker_session=mock_session,
5252
)
5353
patched_get_model_specs.assert_called_once_with(
@@ -70,7 +70,7 @@ def test_jumpstart_common_image_uri(
7070
image_scope="inference",
7171
model_id="pytorch-ic-mobilenet-v2",
7272
model_version="1.*",
73-
instance_type="ml.p2.xlarge",
73+
instance_type="ml.m5.xlarge",
7474
sagemaker_session=mock_session,
7575
)
7676
patched_get_model_specs.assert_called_once_with(
@@ -93,7 +93,7 @@ def test_jumpstart_common_image_uri(
9393
image_scope="training",
9494
model_id="pytorch-ic-mobilenet-v2",
9595
model_version="*",
96-
instance_type="ml.p2.xlarge",
96+
instance_type="ml.m5.xlarge",
9797
sagemaker_session=mock_session,
9898
)
9999
patched_get_model_specs.assert_called_once_with(
@@ -116,7 +116,7 @@ def test_jumpstart_common_image_uri(
116116
image_scope="inference",
117117
model_id="pytorch-ic-mobilenet-v2",
118118
model_version="1.*",
119-
instance_type="ml.p2.xlarge",
119+
instance_type="ml.m5.xlarge",
120120
sagemaker_session=mock_session,
121121
)
122122
patched_get_model_specs.assert_called_once_with(
@@ -137,7 +137,7 @@ def test_jumpstart_common_image_uri(
137137
image_scope="BAD_SCOPE",
138138
model_id="pytorch-ic-mobilenet-v2",
139139
model_version="*",
140-
instance_type="ml.p2.xlarge",
140+
instance_type="ml.m5.xlarge",
141141
)
142142

143143
with pytest.raises(KeyError):
@@ -147,7 +147,7 @@ def test_jumpstart_common_image_uri(
147147
image_scope="training",
148148
model_id="blah",
149149
model_version="*",
150-
instance_type="ml.p2.xlarge",
150+
instance_type="ml.m5.xlarge",
151151
)
152152

153153
with pytest.raises(ValueError):
@@ -157,7 +157,7 @@ def test_jumpstart_common_image_uri(
157157
image_scope="training",
158158
model_id="pytorch-ic-mobilenet-v2",
159159
model_version="*",
160-
instance_type="ml.p2.xlarge",
160+
instance_type="ml.m5.xlarge",
161161
)
162162

163163
with pytest.raises(ValueError):
@@ -166,7 +166,7 @@ def test_jumpstart_common_image_uri(
166166
region="us-west-2",
167167
model_id="pytorch-ic-mobilenet-v2",
168168
model_version="*",
169-
instance_type="ml.p2.xlarge",
169+
instance_type="ml.m5.xlarge",
170170
)
171171

172172
with pytest.raises(ValueError):
@@ -175,7 +175,7 @@ def test_jumpstart_common_image_uri(
175175
region="us-west-2",
176176
image_scope="training",
177177
model_version="*",
178-
instance_type="ml.p2.xlarge",
178+
instance_type="ml.m5.xlarge",
179179
)
180180

181181
with pytest.raises(ValueError):
@@ -184,5 +184,5 @@ def test_jumpstart_common_image_uri(
184184
framework=None,
185185
image_scope="training",
186186
model_id="pytorch-ic-mobilenet-v2",
187-
instance_type="ml.p2.xlarge",
187+
instance_type="ml.m5.xlarge",
188188
)

tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session):
2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

3030
model_id, model_version = "huggingface-spc-bert-base-cased", "*"
31-
instance_type = "ml.p2.xlarge"
31+
instance_type = "ml.m5.xlarge"
32+
training_instance_type = "ml.p3.2xlarge"
3233
region = "us-west-2"
3334

3435
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
@@ -55,7 +56,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session):
5556

5657
assert (
5758
uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:"
58-
"1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04"
59+
"1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04"
5960
)
6061

6162
# training
@@ -65,7 +66,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session):
6566
image_scope="training",
6667
model_id=model_id,
6768
model_version=model_version,
68-
instance_type=instance_type,
69+
instance_type=training_instance_type,
6970
)
7071

7172
framework_class_uri = HuggingFace(
@@ -75,7 +76,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session):
7576
entry_point="some_entry_point",
7677
transformers_version=model_specs.training_ecr_specs.huggingface_transformers_version,
7778
pytorch_version=model_specs.training_ecr_specs.framework_version,
78-
instance_type=instance_type,
79+
instance_type=training_instance_type,
7980
instance_count=1,
8081
sagemaker_session=session,
8182
).training_image_uri(region=region)

tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session):
2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

3030
model_id, model_version = "lightgbm-classification-model", "*"
31-
instance_type = "ml.p2.xlarge"
31+
instance_type = "ml.m5.xlarge"
3232
region = "us-west-2"
3333

3434
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
@@ -53,7 +53,7 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session):
5353
).serving_image_uri(region, instance_type)
5454

5555
assert uri == framework_class_uri
56-
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
56+
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310"
5757

5858
# training
5959
uri = image_uris.retrieve(
@@ -76,4 +76,4 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session):
7676
).training_image_uri(region=region)
7777

7878
assert uri == framework_class_uri
79-
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38"
79+
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38"

0 commit comments

Comments
 (0)