 from sagemaker.spark import defaults
 from sagemaker.jumpstart import artifacts

-
 logger = logging.getLogger(__name__)

 ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
@@ -47,6 +46,8 @@ def retrieve(
     model_version=None,
     tolerate_vulnerable_model=False,
     tolerate_deprecated_model=False,
+    sdk_version=None,
+    inference_tool=None,
 ) -> str:
     """Retrieves the ECR URI for the Docker image matching the given arguments.

@@ -88,6 +89,11 @@ def retrieve(
         tolerate_deprecated_model (bool): True if deprecated versions of model specifications
             should be tolerated without an exception raised. If False, raises an exception
             if the version of the model is deprecated. (Default: False).
+        sdk_version (str): the version of python-sdk that will be used in the image retrieval.
+            (Default: None).
+        inference_tool (str): the tool that will be used to aid in the inference.
+            Valid values: "neuron", None.
+            (Default: None).

     Returns:
         str: The ECR URI for the corresponding SageMaker Docker image.
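With the two new arguments in place, a Neuron-enabled Hugging Face inference image could be requested roughly as in the sketch below. This is a hedged example: the framework, transformers, PyTorch, and Python versions are illustrative placeholders, and the combinations actually available depend on the image URI config shipped with the SDK.

# Illustrative only; version strings are placeholders, not guaranteed to exist.
from sagemaker import image_uris

uri = image_uris.retrieve(
    framework="huggingface",
    region="us-east-1",
    version="4.12.3",                        # transformers version (example)
    base_framework_version="pytorch1.9.1",   # underlying framework (example)
    py_version="py37",
    instance_type="ml.inf1.xlarge",          # an inf* instance implies the "neuron" tool
    image_scope="inference",
    inference_tool="neuron",                 # optional; inferred from instance_type if omitted
)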
@@ -100,7 +106,6 @@ def retrieve(
         DeprecatedJumpStartModelError: If the version of the model is deprecated.
     """
     if is_jumpstart_model_input(model_id, model_version):
-
         return artifacts._retrieve_image_uri(
             model_id,
             model_version,
@@ -118,9 +123,13 @@ def retrieve(
             tolerate_vulnerable_model,
             tolerate_deprecated_model,
         )
-
     if training_compiler_config is None:
-        config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
+        _framework = framework
+        if framework == HUGGING_FACE_FRAMEWORK:
+            inference_tool = _get_inference_tool(inference_tool, instance_type)
+            if inference_tool == "neuron":
+                _framework = f"{framework}-{inference_tool}"
+        config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
     elif framework == HUGGING_FACE_FRAMEWORK:
         config = _config_for_framework_and_scope(
             framework + "-training-compiler", image_scope, accelerator_type
@@ -129,6 +138,7 @@ def retrieve(
         raise ValueError(
             "Unsupported Configuration: Training Compiler is only supported with HuggingFace"
         )
+
     original_version = version
     version = _validate_version_and_set_if_needed(version, config, framework)
     version_config = config["versions"][_version_for_config(version, config)]
@@ -138,7 +148,6 @@ def retrieve(
         full_base_framework_version = version_config["version_aliases"].get(
             base_framework_version, base_framework_version
        )
-
         _validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
         version_config = version_config.get(full_base_framework_version)

@@ -161,25 +170,37 @@ def retrieve(
         pt_or_tf_version = (
             re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
         )
-
         _version = original_version
+
         if repo in [
             "huggingface-pytorch-trcomp-training",
             "huggingface-tensorflow-trcomp-training",
         ]:
             _version = version
+        if repo in ["huggingface-pytorch-inference-neuron"]:
+            if not sdk_version:
+                sdk_version = _get_latest_versions(version_config["sdk_versions"])
+            container_version = sdk_version + "-" + container_version
+            if config.get("version_aliases").get(original_version):
+                _version = config.get("version_aliases")[original_version]
+            if (
+                config.get("versions", {})
+                .get(_version, {})
+                .get("version_aliases", {})
+                .get(base_framework_version, {})
+            ):
+                _base_framework_version = config.get("versions")[_version]["version_aliases"][
+                    base_framework_version
+                ]
+                pt_or_tf_version = (
+                    re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
+                )

         tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
-
     else:
         tag_prefix = version_config.get("tag_prefix", version)

-    tag = _format_tag(
-        tag_prefix,
-        processor,
-        py_version,
-        container_version,
-    )
+    tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)

     if _should_auto_select_container_version(instance_type, distribution):
         container_versions = {
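Reading the new Neuron branch end to end: when the resolved repo is the Hugging Face PyTorch Neuron inference repo, the SDK version is prepended to the container version and the transformers/framework versions are re-resolved through the version aliases before the tag prefix is built. A rough trace with made-up values (the real ones come from the huggingface-neuron image config and the caller):

# Hypothetical values, only to show how the pieces are concatenated.
pt_or_tf_version, _version = "1.9.1", "4.12.3"
sdk_version, container_version = "sdk1.2.3", "ubuntu18.04"

container_version = sdk_version + "-" + container_version    # "sdk1.2.3-ubuntu18.04"
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"    # "1.9.1-transformers4.12.3"
# The final tag is assembled by _format_tag (see the last hunk below).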
@@ -248,6 +269,20 @@ def config_for_framework(framework):
         return json.load(f)


+def _get_inference_tool(inference_tool, instance_type):
+    """Extract the inference tool name from instance type."""
+    if not inference_tool and instance_type:
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match and match[1].startswith("inf"):
+            return "neuron"
+    return inference_tool
+
+
+def _get_latest_versions(list_of_versions):
+    """Extract the latest version from the input list of available versions."""
+    return sorted(list_of_versions, reverse=True)[0]
+
+
 def _validate_accelerator_type(accelerator_type):
     """Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
     if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
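A quick sanity check of the two new helpers, assuming they behave exactly as written above; the SDK version strings are made up for illustration:

# Instance families starting with "inf" (Inferentia) map to the "neuron" tool.
assert _get_inference_tool(None, "ml.inf1.xlarge") == "neuron"
assert _get_inference_tool(None, "ml.c5.xlarge") is None
assert _get_inference_tool("neuron", None) == "neuron"   # an explicit value is passed through

# Note: the "latest" pick is a plain reverse lexicographic sort of the list.
assert _get_latest_versions(["sdk1.17.1", "sdk1.16.0"]) == "sdk1.17.1"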
@@ -310,6 +345,8 @@ def _processor(instance_type, available_processors):

     if instance_type.startswith("local"):
         processor = "cpu" if instance_type == "local" else "gpu"
+    elif instance_type.startswith("neuron"):
+        processor = "neuron"
     else:
         # looks for either "ml.<family>.<size>" or "ml_<family>"
         match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
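The processor lookup now also recognizes a literal "neuron" device string; regular ml.inf1.* instance types still resolve through the ml.<family> regex below. A minimal sketch, assuming "neuron" is listed among the config's available processors:

# Hedged example: "neuron" must appear in available_processors for validation to pass.
assert _processor("neuron", ["cpu", "gpu", "neuron"]) == "neuron"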
@@ -387,8 +424,10 @@ def _validate_arg(arg, available_options, arg_name):
     )


-def _format_tag(tag_prefix, processor, py_version, container_version):
+def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
     """Creates a tag for the image URI."""
+    if inference_tool:
+        return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
     return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)


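When an inference tool is given, the tag swaps the processor slot for the tool name; otherwise the existing layout is unchanged. A minimal check of both paths, reusing the illustrative values from the trace above:

# Hedged example of the two tag layouts produced by _format_tag.
assert _format_tag(
    "1.9.1-transformers4.12.3", "neuron", "py37", "sdk1.2.3-ubuntu18.04", "neuron"
) == "1.9.1-transformers4.12.3-neuron-py37-sdk1.2.3-ubuntu18.04"

assert _format_tag(
    "1.9.1", "gpu", "py38", "cu111-ubuntu20.04"
) == "1.9.1-gpu-py38-cu111-ubuntu20.04"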