@@ -2521,8 +2521,11 @@ def _create_model_request(
25212521 request = {"ModelName" : name , "ExecutionRoleArn" : role }
25222522 if isinstance (container_definition , list ):
25232523 request ["Containers" ] = container_definition
2524+ elif "ModelPackageName" in container_definition :
2525+ request ["Containers" ] = [container_definition ]
25242526 else :
25252527 request ["PrimaryContainer" ] = container_definition
2528+
25262529 if tags :
25272530 request ["Tags" ] = tags
25282531
@@ -2731,7 +2734,7 @@ def create_model_package_from_containers(
27312734 description (str): Model Package description (default: None).
27322735 """
27332736
2734- request = self . _get_create_model_package_request (
2737+ request = get_create_model_package_request (
27352738 model_package_name ,
27362739 model_package_group_name ,
27372740 containers ,
@@ -2747,82 +2750,6 @@ def create_model_package_from_containers(
27472750 )
27482751 return self .sagemaker_client .create_model_package (** request )
27492752
2750- def _get_create_model_package_request (
2751- self ,
2752- model_package_name = None ,
2753- model_package_group_name = None ,
2754- containers = None ,
2755- content_types = None ,
2756- response_types = None ,
2757- inference_instances = None ,
2758- transform_instances = None ,
2759- model_metrics = None ,
2760- metadata_properties = None ,
2761- marketplace_cert = False ,
2762- approval_status = "PendingManualApproval" ,
2763- description = None ,
2764- tags = None ,
2765- ):
2766- """Get request dictionary for CreateModelPackage API.
2767-
2768- Args:
2769- model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
2770- using `model_package_name` makes the Model Package un-versioned (default: None).
2771- model_package_group_name (str): Model Package Group name, exclusive to
2772- `model_package_name`, using `model_package_group_name` makes the Model Package
2773- versioned (default: None).
2774- containers (list): A list of inference containers that can be used for inference
2775- specifications of Model Package (default: None).
2776- content_types (list): The supported MIME types for the input data (default: None).
2777- response_types (list): The supported MIME types for the output data (default: None).
2778- inference_instances (list): A list of the instance types that are used to
2779- generate inferences in real-time (default: None).
2780- transform_instances (list): A list of the instance types on which a transformation
2781- job can be run or on which an endpoint can be deployed (default: None).
2782- model_metrics (ModelMetrics): ModelMetrics object (default: None).
2783- metadata_properties (MetadataProperties): MetadataProperties object (default: None).
2784- marketplace_cert (bool): A boolean value indicating if the Model Package is certified
2785- for AWS Marketplace (default: False).
2786- approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
2787- or "PendingManualApproval" (default: "PendingManualApproval").
2788- description (str): Model Package description (default: None).
2789- """
2790- if all ([model_package_name , model_package_group_name ]):
2791- raise ValueError (
2792- "model_package_name and model_package_group_name cannot be present at the "
2793- "same time."
2794- )
2795- request_dict = {}
2796- if model_package_name is not None :
2797- request_dict ["ModelPackageName" ] = model_package_name
2798- if model_package_group_name is not None :
2799- request_dict ["ModelPackageGroupName" ] = model_package_group_name
2800- if description is not None :
2801- request_dict ["ModelPackageDescription" ] = description
2802- if tags is not None :
2803- request_dict ["Tags" ] = tags
2804- if model_metrics :
2805- request_dict ["ModelMetrics" ] = model_metrics
2806- if metadata_properties :
2807- request_dict ["MetadataProperties" ] = metadata_properties
2808- if containers is not None :
2809- if not all ([content_types , response_types , inference_instances , transform_instances ]):
2810- raise ValueError (
2811- "content_types, response_types, inference_inferences and transform_instances "
2812- "must be provided if containers is present."
2813- )
2814- inference_specification = {
2815- "Containers" : containers ,
2816- "SupportedContentTypes" : content_types ,
2817- "SupportedResponseMIMETypes" : response_types ,
2818- "SupportedRealtimeInferenceInstanceTypes" : inference_instances ,
2819- "SupportedTransformInstanceTypes" : transform_instances ,
2820- }
2821- request_dict ["InferenceSpecification" ] = inference_specification
2822- request_dict ["CertifyForMarketplace" ] = marketplace_cert
2823- request_dict ["ModelApprovalStatus" ] = approval_status
2824- return request_dict
2825-
28262753 def wait_for_model_package (self , model_package_name , poll = 5 ):
28272754 """Wait for an Amazon SageMaker endpoint deployment to complete.
28282755
@@ -4097,6 +4024,160 @@ def account_id(self) -> str:
40974024 return sts_client .get_caller_identity ()["Account" ]
40984025
40994026
4027+ def get_model_package_args (
4028+ content_types ,
4029+ response_types ,
4030+ inference_instances ,
4031+ transform_instances ,
4032+ model_package_name = None ,
4033+ model_package_group_name = None ,
4034+ model_data = None ,
4035+ image_uri = None ,
4036+ model_metrics = None ,
4037+ metadata_properties = None ,
4038+ marketplace_cert = False ,
4039+ approval_status = None ,
4040+ description = None ,
4041+ tags = None ,
4042+ container_def_list = None ,
4043+ ):
4044+ """Get arguments for create_model_package method.
4045+
4046+ Args:
4047+ content_types (list): The supported MIME types for the input data.
4048+ response_types (list): The supported MIME types for the output data.
4049+ inference_instances (list): A list of the instance types that are used to
4050+ generate inferences in real-time.
4051+ transform_instances (list): A list of the instance types on which a transformation
4052+ job can be run or on which an endpoint can be deployed.
4053+ model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
4054+ using `model_package_name` makes the Model Package un-versioned (default: None).
4055+ model_package_group_name (str): Model Package Group name, exclusive to
4056+ `model_package_name`, using `model_package_group_name` makes the Model Package
4057+ versioned (default: None).
4058+ image_uri (str): Inference image uri for the container. Model class' self.image will
4059+ be used if it is None (default: None).
4060+ model_metrics (ModelMetrics): ModelMetrics object (default: None).
4061+ metadata_properties (MetadataProperties): MetadataProperties object (default: None).
4062+ marketplace_cert (bool): A boolean value indicating if the Model Package is certified
4063+ for AWS Marketplace (default: False).
4064+ approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
4065+ or "PendingManualApproval" (default: "PendingManualApproval").
4066+ description (str): Model Package description (default: None).
4067+ container_def_list (list): A list of container defintiions.
4068+ Returns:
4069+ dict: A dictionary of method argument names and values.
4070+ """
4071+ if container_def_list is not None :
4072+ containers = container_def_list
4073+ else :
4074+ container = {
4075+ "Image" : image_uri ,
4076+ "ModelDataUrl" : model_data ,
4077+ }
4078+ containers = [container ]
4079+
4080+ model_package_args = {
4081+ "containers" : containers ,
4082+ "content_types" : content_types ,
4083+ "response_types" : response_types ,
4084+ "inference_instances" : inference_instances ,
4085+ "transform_instances" : transform_instances ,
4086+ "marketplace_cert" : marketplace_cert ,
4087+ }
4088+
4089+ if model_package_name is not None :
4090+ model_package_args ["model_package_name" ] = model_package_name
4091+ if model_package_group_name is not None :
4092+ model_package_args ["model_package_group_name" ] = model_package_group_name
4093+ if model_metrics is not None :
4094+ model_package_args ["model_metrics" ] = model_metrics ._to_request_dict ()
4095+ if metadata_properties is not None :
4096+ model_package_args ["metadata_properties" ] = metadata_properties ._to_request_dict ()
4097+ if approval_status is not None :
4098+ model_package_args ["approval_status" ] = approval_status
4099+ if description is not None :
4100+ model_package_args ["description" ] = description
4101+ if tags is not None :
4102+ model_package_args ["tags" ] = tags
4103+ return model_package_args
4104+
4105+
4106+ def get_create_model_package_request (
4107+ model_package_name = None ,
4108+ model_package_group_name = None ,
4109+ containers = None ,
4110+ content_types = None ,
4111+ response_types = None ,
4112+ inference_instances = None ,
4113+ transform_instances = None ,
4114+ model_metrics = None ,
4115+ metadata_properties = None ,
4116+ marketplace_cert = False ,
4117+ approval_status = "PendingManualApproval" ,
4118+ description = None ,
4119+ tags = None ,
4120+ ):
4121+ """Get request dictionary for CreateModelPackage API.
4122+
4123+ Args:
4124+ model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
4125+ using `model_package_name` makes the Model Package un-versioned (default: None).
4126+ model_package_group_name (str): Model Package Group name, exclusive to
4127+ `model_package_name`, using `model_package_group_name` makes the Model Package
4128+ versioned (default: None).
4129+ containers (list): A list of inference containers that can be used for inference
4130+ specifications of Model Package (default: None).
4131+ content_types (list): The supported MIME types for the input data (default: None).
4132+ response_types (list): The supported MIME types for the output data (default: None).
4133+ inference_instances (list): A list of the instance types that are used to
4134+ generate inferences in real-time (default: None).
4135+ transform_instances (list): A list of the instance types on which a transformation
4136+ job can be run or on which an endpoint can be deployed (default: None).
4137+ model_metrics (ModelMetrics): ModelMetrics object (default: None).
4138+ metadata_properties (MetadataProperties): MetadataProperties object (default: None).
4139+ marketplace_cert (bool): A boolean value indicating if the Model Package is certified
4140+ for AWS Marketplace (default: False).
4141+ approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
4142+ or "PendingManualApproval" (default: "PendingManualApproval").
4143+ description (str): Model Package description (default: None).
4144+ """
4145+ if all ([model_package_name , model_package_group_name ]):
4146+ raise ValueError (
4147+ "model_package_name and model_package_group_name cannot be present at the " "same time."
4148+ )
4149+ request_dict = {}
4150+ if model_package_name is not None :
4151+ request_dict ["ModelPackageName" ] = model_package_name
4152+ if model_package_group_name is not None :
4153+ request_dict ["ModelPackageGroupName" ] = model_package_group_name
4154+ if description is not None :
4155+ request_dict ["ModelPackageDescription" ] = description
4156+ if tags is not None :
4157+ request_dict ["Tags" ] = tags
4158+ if model_metrics :
4159+ request_dict ["ModelMetrics" ] = model_metrics
4160+ if metadata_properties :
4161+ request_dict ["MetadataProperties" ] = metadata_properties
4162+ if containers is not None :
4163+ if not all ([content_types , response_types , inference_instances , transform_instances ]):
4164+ raise ValueError (
4165+ "content_types, response_types, inference_inferences and transform_instances "
4166+ "must be provided if containers is present."
4167+ )
4168+ inference_specification = {
4169+ "Containers" : containers ,
4170+ "SupportedContentTypes" : content_types ,
4171+ "SupportedResponseMIMETypes" : response_types ,
4172+ "SupportedRealtimeInferenceInstanceTypes" : inference_instances ,
4173+ "SupportedTransformInstanceTypes" : transform_instances ,
4174+ }
4175+ request_dict ["InferenceSpecification" ] = inference_specification
4176+ request_dict ["CertifyForMarketplace" ] = marketplace_cert
4177+ request_dict ["ModelApprovalStatus" ] = approval_status
4178+ return request_dict
4179+
4180+
41004181def update_args (args : Dict [str , Any ], ** kwargs ):
41014182 """Updates the request arguments dict with the value if populated.
41024183
0 commit comments