14
14
from __future__ import absolute_import
15
15
16
16
from typing import Optional
17
- from sagemaker import image_uris
18
17
from sagemaker .jumpstart .constants import (
19
18
DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
20
19
)
21
20
from sagemaker .jumpstart .enums import (
22
21
JumpStartModelType ,
23
22
JumpStartScriptScope ,
24
- ModelFramework ,
25
23
)
26
24
from sagemaker .jumpstart .utils import (
27
25
get_region_fallback ,
@@ -35,16 +33,8 @@ def _retrieve_image_uri(
35
33
model_version : str ,
36
34
image_scope : str ,
37
35
hub_arn : Optional [str ] = None ,
38
- framework : Optional [str ] = None ,
39
36
region : Optional [str ] = None ,
40
- version : Optional [str ] = None ,
41
- py_version : Optional [str ] = None ,
42
37
instance_type : Optional [str ] = None ,
43
- accelerator_type : Optional [str ] = None ,
44
- container_version : Optional [str ] = None ,
45
- distribution : Optional [str ] = None ,
46
- base_framework_version : Optional [str ] = None ,
47
- training_compiler_config : Optional [str ] = None ,
48
38
tolerate_vulnerable_model : bool = False ,
49
39
tolerate_deprecated_model : bool = False ,
50
40
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
@@ -66,30 +56,11 @@ def _retrieve_image_uri(
66
56
image_scope (str): The image type, i.e. what it is used for.
67
57
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
68
58
``image_scope`` is ignored.
69
- framework (str): The name of the framework or algorithm.
70
59
region (str): The AWS region. (Default: None).
71
- version (str): The framework or algorithm version. This is required if there is
72
- more than one supported version for the given framework or algorithm.
73
- (Default: None).
74
- py_version (str): The Python version. This is required if there is
75
- more than one supported Python version for the given framework version.
76
60
instance_type (str): The SageMaker instance type. For supported types, see
77
61
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
78
62
there are different images for different processor types.
79
63
(Default: None).
80
- accelerator_type (str): Elastic Inference accelerator type. For more, see
81
- https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
82
- (Default: None).
83
- container_version (str): the version of docker image.
84
- Ideally the value of parameter should be created inside the framework.
85
- For custom use, see the list of supported container versions:
86
- https://github.com/aws/deep-learning-containers/blob/master/available_images.md.
87
- (Default: None).
88
- distribution (dict): A dictionary with information on how to run distributed training.
89
- (Default: None).
90
- training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
91
- A configuration class for the SageMaker Training Compiler.
92
- (Default: None).
93
64
tolerate_vulnerable_model (bool): True if vulnerable versions of model
94
65
specifications should be tolerated (exception not raised). If False, raises an
95
66
exception if the script used by this version of the model has dependencies with known
@@ -142,14 +113,12 @@ def _retrieve_image_uri(
142
113
ecr_uri = model_specs .hosting_ecr_uri
143
114
return ecr_uri
144
115
145
- ecr_specs = model_specs .hosting_ecr_specs
146
- if ecr_specs is None :
147
- raise ValueError (
148
- f"No inference ECR configuration found for JumpStart model ID '{ model_id } ' "
149
- f"with { instance_type } instance type in { region } . "
150
- "Please try another instance type or region."
151
- )
152
- elif image_scope == JumpStartScriptScope .TRAINING :
116
+ raise ValueError (
117
+ f"No inference ECR configuration found for JumpStart model ID '{ model_id } ' "
118
+ f"with { instance_type } instance type in { region } . "
119
+ "Please try another instance type or region."
120
+ )
121
+ if image_scope == JumpStartScriptScope .TRAINING :
153
122
training_instance_type_variants = model_specs .training_instance_type_variants
154
123
if training_instance_type_variants :
155
124
image_uri = training_instance_type_variants .get_image_uri (
@@ -161,65 +130,10 @@ def _retrieve_image_uri(
161
130
ecr_uri = model_specs .training_ecr_uri
162
131
return ecr_uri
163
132
164
- ecr_specs = model_specs .training_ecr_specs
165
- if ecr_specs is None :
166
- raise ValueError (
167
- f"No training ECR configuration found for JumpStart model ID '{ model_id } ' "
168
- f"with { instance_type } instance type in { region } . "
169
- "Please try another instance type or region."
170
- )
171
- if framework is not None and framework != ecr_specs .framework :
172
- raise ValueError (
173
- f"Incorrect container framework '{ framework } ' for JumpStart model ID '{ model_id } ' "
174
- f"and version '{ model_version } '."
175
- )
176
-
177
- if version is not None and version != ecr_specs .framework_version :
178
- raise ValueError (
179
- f"Incorrect container framework version '{ version } ' for JumpStart model ID "
180
- f"'{ model_id } ' and version '{ model_version } '."
181
- )
182
-
183
- if py_version is not None and py_version != ecr_specs .py_version :
184
133
raise ValueError (
185
- f"Incorrect python version '{ py_version } ' for JumpStart model ID '{ model_id } ' "
186
- f"and version '{ model_version } '."
187
- )
188
-
189
- base_framework_version_override : Optional [str ] = None
190
- version_override : Optional [str ] = None
191
- if ecr_specs .framework == ModelFramework .HUGGINGFACE :
192
- base_framework_version_override = ecr_specs .framework_version
193
- version_override = ecr_specs .huggingface_transformers_version
194
-
195
- if image_scope == JumpStartScriptScope .TRAINING :
196
- return image_uris .get_training_image_uri (
197
- region = region ,
198
- framework = ecr_specs .framework ,
199
- framework_version = version_override or ecr_specs .framework_version ,
200
- py_version = ecr_specs .py_version ,
201
- image_uri = None ,
202
- distribution = None ,
203
- compiler_config = None ,
204
- tensorflow_version = None ,
205
- pytorch_version = base_framework_version_override or base_framework_version ,
206
- instance_type = instance_type ,
134
+ f"No training ECR configuration found for JumpStart model ID '{ model_id } ' "
135
+ f"with { instance_type } instance type in { region } . "
136
+ "Please try another instance type or region."
207
137
)
208
- if base_framework_version_override is not None :
209
- base_framework_version_override = f"pytorch{ base_framework_version_override } "
210
138
211
- return image_uris .retrieve (
212
- framework = ecr_specs .framework ,
213
- region = region ,
214
- version = version_override or ecr_specs .framework_version ,
215
- py_version = ecr_specs .py_version ,
216
- instance_type = instance_type ,
217
- hub_arn = hub_arn ,
218
- accelerator_type = accelerator_type ,
219
- image_scope = image_scope ,
220
- container_version = container_version ,
221
- distribution = distribution ,
222
- base_framework_version = base_framework_version_override or base_framework_version ,
223
- training_compiler_config = training_compiler_config ,
224
- config_name = config_name ,
225
- )
139
+ raise ValueError (f"Invalid scope: { image_scope } " )
0 commit comments