3535from sagemaker .serverless import ServerlessInferenceConfig
3636from sagemaker .transformer import Transformer
3737from sagemaker .jumpstart .utils import add_jumpstart_tags , get_jumpstart_base_name_if_jumpstart_model
38- from sagemaker .utils import unique_name_from_base
38+ from sagemaker .utils import (
39+ unique_name_from_base ,
40+ update_container_with_inference_params ,
41+ )
3942from sagemaker .async_inference import AsyncInferenceConfig
4043from sagemaker .predictor_async import AsyncPredictor
4144from sagemaker .workflow import is_pipeline_variable
@@ -310,6 +313,12 @@ def register(
310313 customer_metadata_properties = None ,
311314 validation_specification = None ,
312315 domain = None ,
316+ task = None ,
317+ sample_payload_url = None ,
318+ framework = None ,
319+ framework_version = None ,
320+ nearest_model_name = None ,
321+ data_input_configuration = None ,
313322 ):
314323 """Creates a model package for creating SageMaker models or listing on Marketplace.
315324
@@ -339,6 +348,18 @@ def register(
339348 metadata properties (default: None).
340349 domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
341350 "MACHINE_LEARNING" (default: None).
351+ sample_payload_url (str): The S3 path where the sample payload is stored
352+ (default: None).
353+ task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
354+ "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
355+ "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
356+ framework (str): Machine learning framework of the model package container image
357+ (default: None).
358+ framework_version (str): Framework version of the Model Package Container Image
359+ (default: None).
360+ nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
361+ Amazon SageMaker Inference Recommender (default: None).
362+ data_input_configuration (str): Input object for the model (default: None).
342363
343364 Returns:
344365 A `sagemaker.model.ModelPackage` instance or pipeline step arguments
@@ -349,10 +370,22 @@ def register(
349370 raise ValueError ("SageMaker Model Package cannot be created without model data." )
350371 if image_uri is not None :
351372 self .image_uri = image_uri
373+
352374 if model_package_group_name is not None :
353375 container_def = self .prepare_container_def ()
376+ update_container_with_inference_params (
377+ framework = framework ,
378+ framework_version = framework_version ,
379+ nearest_model_name = nearest_model_name ,
380+ data_input_configuration = data_input_configuration ,
381+ container_obj = container_def ,
382+ )
354383 else :
355- container_def = {"Image" : self .image_uri , "ModelDataUrl" : self .model_data }
384+ container_def = {
385+ "Image" : self .image_uri ,
386+ "ModelDataUrl" : self .model_data ,
387+ }
388+
356389 model_pkg_args = sagemaker .get_model_package_args (
357390 content_types ,
358391 response_types ,
@@ -370,6 +403,8 @@ def register(
370403 customer_metadata_properties = customer_metadata_properties ,
371404 validation_specification = validation_specification ,
372405 domain = domain ,
406+ sample_payload_url = sample_payload_url ,
407+ task = task ,
373408 )
374409 model_package = self .sagemaker_session .create_model_package_from_containers (
375410 ** model_pkg_args
0 commit comments