3232from sagemaker .inputs import CompilationInput
3333from sagemaker .deprecations import removed_kwargs
3434from sagemaker .predictor import PredictorBase
35+ from sagemaker .serverless import ServerlessInferenceConfig
3536from sagemaker .transformer import Transformer
3637
3738LOGGER = logging .getLogger ("sagemaker" )
@@ -209,7 +210,7 @@ def register(
209210 model_package_arn = model_package .get ("ModelPackageArn" ),
210211 )
211212
212- def _init_sagemaker_session_if_does_not_exist (self , instance_type ):
213+ def _init_sagemaker_session_if_does_not_exist (self , instance_type = None ):
213214 """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
214215
215216 The type of session object is determined by the instance type.
@@ -688,8 +689,8 @@ def compile(
688689
689690 def deploy (
690691 self ,
691- initial_instance_count ,
692- instance_type ,
692+ initial_instance_count = None ,
693+ instance_type = None ,
693694 serializer = None ,
694695 deserializer = None ,
695696 accelerator_type = None ,
@@ -698,6 +699,7 @@ def deploy(
698699 kms_key = None ,
699700 wait = True ,
700701 data_capture_config = None ,
702+ serverless_inference_config = None ,
701703 ** kwargs ,
702704 ):
703705 """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -715,9 +717,13 @@ def deploy(
715717
716718 Args:
717719 initial_instance_count (int): The initial number of instances to run
718- in the ``Endpoint`` created from this ``Model``.
720+ in the ``Endpoint`` created from this ``Model``. If not using
721+ serverless inference, then it needs to be a number greater than or
722+ equal to 1 (default: None).
719723 instance_type (str): The EC2 instance type to deploy this Model to.
720- For example, 'ml.p2.xlarge', or 'local' for local mode.
724+ For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
725+ serverless inference, then it is required to deploy a model.
726+ (default: None)
721727 serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
722728 serializer object, used to encode data for an inference endpoint
723729 (default: None). If ``serializer`` is not None, then
@@ -746,7 +752,17 @@ def deploy(
746752 data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
747753 configuration related to Endpoint data capture for use with
748754 Amazon SageMaker Model Monitoring. Default: None.
749-
755+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
756+ Specifies configuration related to a serverless endpoint. Use this configuration
757+ when creating a serverless endpoint and making serverless inferences. If an
758+ empty object is passed, the pre-defined values in the
759+ ``ServerlessInferenceConfig`` class are used to deploy the endpoint (default: None).
760+ Raises:
761+ ValueError: If the argument combination check fails in any of these circumstances:
762+ - If no role is specified or
763+ - If serverless inference config is not specified and either instance type or
764+ instance count is not specified or
765+ - If a wrong type of object is provided as serverless inference config
750766 Returns:
751767 callable[string, sagemaker.session.Session] or None: Invocation of
752768 ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
@@ -758,27 +774,47 @@ def deploy(
758774 if self .role is None :
759775 raise ValueError ("Role can not be null for deploying a model" )
760776
761- if instance_type .startswith ("ml.inf" ) and not self ._is_compiled_model :
777+ is_serverless = serverless_inference_config is not None
778+ if not is_serverless and not (instance_type and initial_instance_count ):
779+ raise ValueError (
780+ "Must specify instance type and instance count unless using serverless inference"
781+ )
782+
783+ if is_serverless and not isinstance (serverless_inference_config , ServerlessInferenceConfig ):
784+ raise ValueError (
785+ "serverless_inference_config needs to be a ServerlessInferenceConfig object"
786+ )
787+
788+ if instance_type and instance_type .startswith ("ml.inf" ) and not self ._is_compiled_model :
762789 LOGGER .warning (
763790 "Your model is not compiled. Please compile your model before using Inferentia."
764791 )
765792
766- compiled_model_suffix = "-" .join (instance_type .split ("." )[:- 1 ])
767- if self ._is_compiled_model :
793+ compiled_model_suffix = None if is_serverless else "-" .join (instance_type .split ("." )[:- 1 ])
794+ if self ._is_compiled_model and not is_serverless :
768795 self ._ensure_base_name_if_needed (self .image_uri )
769796 if self ._base_name is not None :
770797 self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
771798
772799 self ._create_sagemaker_model (instance_type , accelerator_type , tags )
800+
801+ serverless_inference_config_dict = (
802+ serverless_inference_config ._to_request_dict () if is_serverless else None
803+ )
773804 production_variant = sagemaker .production_variant (
774- self .name , instance_type , initial_instance_count , accelerator_type = accelerator_type
805+ self .name ,
806+ instance_type ,
807+ initial_instance_count ,
808+ accelerator_type = accelerator_type ,
809+ serverless_inference_config = serverless_inference_config_dict ,
775810 )
776811 if endpoint_name :
777812 self .endpoint_name = endpoint_name
778813 else :
779814 base_endpoint_name = self ._base_name or utils .base_from_name (self .name )
780- if self ._is_compiled_model and not base_endpoint_name .endswith (compiled_model_suffix ):
781- base_endpoint_name = "-" .join ((base_endpoint_name , compiled_model_suffix ))
815+ if self ._is_compiled_model and not is_serverless :
816+ if not base_endpoint_name .endswith (compiled_model_suffix ):
817+ base_endpoint_name = "-" .join ((base_endpoint_name , compiled_model_suffix ))
782818 self .endpoint_name = utils .name_from_base (base_endpoint_name )
783819
784820 data_capture_config_dict = None
0 commit comments