@@ -710,7 +710,7 @@ def wait_for_model_package(self, model_package_name, poll=5):
710710 return desc
711711
712712 def create_endpoint_config (self , name , model_name , initial_instance_count , instance_type ,
713- accelerator_type = None , tags = None ):
713+ accelerator_type = None , tags = None , kms_key = None ):
714714 """Create an Amazon SageMaker endpoint configuration.
715715
716716 The endpoint configuration identifies the Amazon SageMaker model (created using the
@@ -738,12 +738,21 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
738738
739739 tags = tags or []
740740
741- self .sagemaker_client .create_endpoint_config (
742- EndpointConfigName = name ,
743- ProductionVariants = [production_variant (model_name , instance_type , initial_instance_count ,
744- accelerator_type = accelerator_type )],
745- Tags = tags
746- )
741+ request = {
742+ 'EndpointConfigName' : name ,
743+ 'ProductionVariants' : [
744+ production_variant (model_name , instance_type , initial_instance_count ,
745+ accelerator_type = accelerator_type )
746+ ],
747+ }
748+
749+ if tags is not None :
750+ request ['Tags' ] = tags
751+
752+ if kms_key is not None :
753+ request ['KmsKeyId' ] = kms_key
754+
755+ self .sagemaker_client .create_endpoint_config (** request )
747756 return name
748757
749758 def create_endpoint (self , endpoint_name , config_name , tags = None , wait = True ):
@@ -1032,13 +1041,15 @@ def endpoint_from_model_data(self, model_s3_location, deployment_image, initial_
10321041 self .create_endpoint (endpoint_name = name , config_name = name , wait = wait )
10331042 return name
10341043
1035- def endpoint_from_production_variants (self , name , production_variants , tags = None , wait = True ):
1044+ def endpoint_from_production_variants (self , name , production_variants , tags = None , kms_key = None , wait = True ):
10361045 """Create an SageMaker ``Endpoint`` from a list of production variants.
10371046
10381047 Args:
10391048 name (str): The name of the ``Endpoint`` to create.
10401049 production_variants (list[dict[str, str]]): The list of production variants to deploy.
10411050 tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint (default: None).
1051+ kms_key (str): The KMS key that is used to encrypt the data on the storage volume attached
1052+ to the instance hosting the endpoint.
10421053 wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True).
10431054
10441055 Returns:
@@ -1050,6 +1061,8 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
10501061 config_options = {'EndpointConfigName' : name , 'ProductionVariants' : production_variants }
10511062 if tags :
10521063 config_options ['Tags' ] = tags
1064+ if kms_key :
1065+ config_options ['KmsKeyId' ] = kms_key
10531066
10541067 self .sagemaker_client .create_endpoint_config (** config_options )
10551068 return self .create_endpoint (endpoint_name = name , config_name = name , tags = tags , wait = wait )
0 commit comments