@@ -140,6 +140,108 @@ def __init__(
140140
141141 self .model_server_workers = model_server_workers
142142
143+ def register (
144+ self ,
145+ content_types ,
146+ response_types ,
147+ inference_instances ,
148+ transform_instances ,
149+ model_package_name = None ,
150+ model_package_group_name = None ,
151+ image_uri = None ,
152+ model_metrics = None ,
153+ metadata_properties = None ,
154+ marketplace_cert = False ,
155+ approval_status = None ,
156+ description = None ,
157+ drift_check_baselines = None ,
158+ customer_metadata_properties = None ,
159+ domain = None ,
160+ sample_payload_url = None ,
161+ task = None ,
162+ framework = None ,
163+ framework_version = None ,
164+ nearest_model_name = None ,
165+ data_input_configuration = None ,
166+ ):
167+ """Creates a model package for creating SageMaker models or listing on Marketplace.
168+
169+ Args:
170+ content_types (list): The supported MIME types for the input data.
171+ response_types (list): The supported MIME types for the output data.
172+ inference_instances (list): A list of the instance types that are used to
173+ generate inferences in real-time.
174+ transform_instances (list): A list of the instance types on which a transformation
175+ job can be run or on which an endpoint can be deployed.
176+ model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
177+ using `model_package_name` makes the Model Package un-versioned (default: None).
178+ model_package_group_name (str): Model Package Group name, exclusive to
179+ `model_package_name`, using `model_package_group_name` makes the Model Package
180+ versioned (default: None).
181+ image_uri (str): Inference image uri for the container. Model class' self.image will
182+ be used if it is None (default: None).
183+ model_metrics (ModelMetrics): ModelMetrics object (default: None).
184+ metadata_properties (MetadataProperties): MetadataProperties (default: None).
185+ marketplace_cert (bool): A boolean value indicating if the Model Package is certified
186+ for AWS Marketplace (default: False).
187+ approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
188+ or "PendingManualApproval" (default: "PendingManualApproval").
189+ description (str): Model Package description (default: None).
190+ drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
191+ customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
192+ metadata properties (default: None).
193+ domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
194+ "MACHINE_LEARNING" (default: None).
195+ sample_payload_url (str): The S3 path where the sample payload is stored
196+ (default: None).
197+ task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
198+ "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
199+ "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
200+ framework (str): Machine learning framework of the model package container image
201+ (default: None).
202+ framework_version (str): Framework version of the Model Package Container Image
203+ (default: None).
204+ nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
205+ Amazon SageMaker Inference Recommender (default: None).
206+ data_input_configuration (str): Input object for the model (default: None).
207+
208+ Returns:
209+ str: A string of SageMaker Model Package ARN.
210+ """
211+ instance_type = inference_instances [0 ]
212+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
213+
214+ if image_uri :
215+ self .image_uri = image_uri
216+ if not self .image_uri :
217+ self .image_uri = self .serving_image_uri (
218+ region_name = self .sagemaker_session .boto_session .region_name ,
219+ instance_type = instance_type ,
220+ )
221+ return super (ChainerModel , self ).register (
222+ content_types ,
223+ response_types ,
224+ inference_instances ,
225+ transform_instances ,
226+ model_package_name ,
227+ model_package_group_name ,
228+ image_uri ,
229+ model_metrics ,
230+ metadata_properties ,
231+ marketplace_cert ,
232+ approval_status ,
233+ description ,
234+ drift_check_baselines = drift_check_baselines ,
235+ customer_metadata_properties = customer_metadata_properties ,
236+ domain = domain ,
237+ sample_payload_url = sample_payload_url ,
238+ task = task ,
239+ framework = (framework or self ._framework_name ).upper (),
240+ framework_version = framework_version or self .framework_version ,
241+ nearest_model_name = nearest_model_name ,
242+ data_input_configuration = data_input_configuration ,
243+ )
244+
143245 def prepare_container_def (
144246 self , instance_type = None , accelerator_type = None , serverless_inference_config = None
145247 ):
0 commit comments