@@ -2778,6 +2778,7 @@ def create_model_package_from_containers(
27782778 approval_status = "PendingManualApproval" ,
27792779 description = None ,
27802780 drift_check_baselines = None ,
2781+ customer_metadata_properties = None ,
27812782 ):
27822783 """Get request dictionary for CreateModelPackage API.
27832784
@@ -2803,6 +2804,9 @@ def create_model_package_from_containers(
28032804 or "PendingManualApproval" (default: "PendingManualApproval").
28042805 description (str): Model Package description (default: None).
28052806 drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
2807+ customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
2808+ metadata properties (default: None).
2809+
28062810 """
28072811
28082812 request = get_create_model_package_request (
@@ -2819,7 +2823,17 @@ def create_model_package_from_containers(
28192823 approval_status ,
28202824 description ,
28212825 drift_check_baselines = drift_check_baselines ,
2826+ customer_metadata_properties = customer_metadata_properties ,
28222827 )
2828+ if model_package_group_name is not None :
2829+ try :
2830+ self .sagemaker_client .describe_model_package_group (
2831+ ModelPackageGroupName = request ["ModelPackageGroupName" ]
2832+ )
2833+ except ClientError :
2834+ self .sagemaker_client .create_model_package_group (
2835+ ModelPackageGroupName = request ["ModelPackageGroupName" ]
2836+ )
28232837 return self .sagemaker_client .create_model_package (** request )
28242838
28252839 def wait_for_model_package (self , model_package_name , poll = 5 ):
@@ -4120,6 +4134,7 @@ def get_model_package_args(
41204134 tags = None ,
41214135 container_def_list = None ,
41224136 drift_check_baselines = None ,
4137+ customer_metadata_properties = None ,
41234138):
41244139 """Get arguments for create_model_package method.
41254140
@@ -4148,6 +4163,8 @@ def get_model_package_args(
41484163 (default: None).
41494164 container_def_list (list): A list of container defintiions (default: None).
41504165 drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
4166+ customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
4167+ metadata properties (default: None).
41514168 Returns:
41524169 dict: A dictionary of method argument names and values.
41534170 """
@@ -4185,6 +4202,8 @@ def get_model_package_args(
41854202 model_package_args ["description" ] = description
41864203 if tags is not None :
41874204 model_package_args ["tags" ] = tags
4205+ if customer_metadata_properties is not None :
4206+ model_package_args ["customer_metadata_properties" ] = customer_metadata_properties
41884207 return model_package_args
41894208
41904209
@@ -4203,6 +4222,7 @@ def get_create_model_package_request(
42034222 description = None ,
42044223 tags = None ,
42054224 drift_check_baselines = None ,
4225+ customer_metadata_properties = None ,
42064226):
42074227 """Get request dictionary for CreateModelPackage API.
42084228
@@ -4229,6 +4249,8 @@ def get_create_model_package_request(
42294249 tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
42304250 (default: None).
42314251 drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
4252+ customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
4253+ metadata properties (default: None).
42324254 """
42334255
42344256 if all ([model_package_name , model_package_group_name ]):
@@ -4250,6 +4272,8 @@ def get_create_model_package_request(
42504272 request_dict ["DriftCheckBaselines" ] = drift_check_baselines
42514273 if metadata_properties :
42524274 request_dict ["MetadataProperties" ] = metadata_properties
4275+ if customer_metadata_properties is not None :
4276+ request_dict ["CustomerMetadataProperties" ] = customer_metadata_properties
42534277 if containers is not None :
42544278 if not all ([content_types , response_types , inference_instances , transform_instances ]):
42554279 raise ValueError (
0 commit comments