@@ -133,11 +133,7 @@ def __init__(
133133 py_version (str): Python version you want to use for executing your
134134 model training code. Defaults to ``None``. Required unless
135135 ``image_uri`` is provided.
136- image_uri (str): A Docker image URI. Defaults to None. For serverless
137- inferece, it is required. More image information can be found in
138- `Amazon SageMaker provided algorithms and Deep Learning Containers
139- <https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html>`_.
140- For instance based inference, if not specified, a
136+ image_uri (str): A Docker image URI. Defaults to None. If not specified, a
141137 default image for PyTorch will be used. If ``framework_version``
142138 or ``py_version`` are ``None``, then ``image_uri`` is required. If
143139 also ``None``, then a ``ValueError`` will be raised.
@@ -272,7 +268,7 @@ def deploy(
272268 is not None. Otherwise, return None.
273269 """
274270
275- if not self .image_uri and instance_type .startswith ("ml.inf" ):
271+ if not self .image_uri and instance_type is not None and instance_type .startswith ("ml.inf" ):
276272 self .image_uri = self .serving_image_uri (
277273 region_name = self .sagemaker_session .boto_session .region_name ,
278274 instance_type = instance_type ,
@@ -365,7 +361,9 @@ def register(
365361 drift_check_baselines = drift_check_baselines ,
366362 )
367363
368- def prepare_container_def (self , instance_type = None , accelerator_type = None ):
364+ def prepare_container_def (
365+ self , instance_type = None , accelerator_type = None , serverless_inference_config = None
366+ ):
369367 """A container definition with framework configuration set in model environment variables.
370368
371369 Args:
@@ -374,21 +372,27 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
374372 accelerator_type (str): The Elastic Inference accelerator type to
375373 deploy to the instance for loading and making inferences to the
376374 model.
375+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
376+ Specifies configuration related to serverless endpoint. Instance type is
377+ not provided in serverless inference. So this is used to find image URIs.
377378
378379 Returns:
379380 dict[str, str]: A container definition object usable with the
380381 CreateModel API.
381382 """
382383 deploy_image = self .image_uri
383384 if not deploy_image :
384- if instance_type is None :
385+ if instance_type is None and serverless_inference_config is None :
385386 raise ValueError (
386387 "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
387388 )
388389
389390 region_name = self .sagemaker_session .boto_session .region_name
390391 deploy_image = self .serving_image_uri (
391- region_name , instance_type , accelerator_type = accelerator_type
392+ region_name ,
393+ instance_type ,
394+ accelerator_type = accelerator_type ,
395+ serverless_inference_config = serverless_inference_config ,
392396 )
393397
394398 deploy_key_prefix = model_code_key_prefix (self .key_prefix , self .name , deploy_image )
@@ -402,7 +406,13 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
402406 deploy_image , self .repacked_model_data or self .model_data , deploy_env
403407 )
404408
405- def serving_image_uri (self , region_name , instance_type , accelerator_type = None ):
409+ def serving_image_uri (
410+ self ,
411+ region_name ,
412+ instance_type = None ,
413+ accelerator_type = None ,
414+ serverless_inference_config = None ,
415+ ):
406416 """Create a URI for the serving image.
407417
408418 Args:
@@ -412,6 +422,9 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
412422 accelerator_type (str): The Elastic Inference accelerator type to
413423 deploy to the instance for loading and making inferences to the
414424 model.
425+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
426+ Specifies configuration related to serverless endpoint. Instance type is
426+ not provided in serverless inference. So this is used to determine device type.
415428
416429 Returns:
417430 str: The appropriate image URI based on the given parameters.
@@ -432,4 +445,5 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
432445 accelerator_type = accelerator_type ,
433446 image_scope = "inference" ,
434447 base_framework_version = base_framework_version ,
448+ serverless_inference_config = serverless_inference_config ,
435449 )
0 commit comments