diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 3d5c645c73..b1d58eceb9 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -169,6 +169,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -216,6 +217,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -254,6 +257,7 @@ def register( framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) def prepare_container_def( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 38fd60d4e3..b757994fd5 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1684,6 +1684,7 @@ def register( framework_version=None, nearest_model_name=None, data_input_configuration=None, + skip_model_validation=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1729,6 +1730,8 @@ def register( nearest_model_name (str): Name of a pre-trained machine learning benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str): Input object for the model (default: None). + skip_model_validation (str): Indicates if you want to skip model validation. + Values can be "All" or "None" (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1772,6 +1775,7 @@ def register( framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 74eeeb2546..7f733bfd11 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -353,6 +353,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -401,6 +402,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -447,6 +450,7 @@ def register( ], nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) def prepare_container_def( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index d0f833795c..5dea4cdd29 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -400,6 +400,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -447,6 +448,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -497,6 +500,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, + skip_model_validation=skip_model_validation, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 6983fa64b9..199adead2a 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -171,6 +171,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -218,6 +219,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -256,6 +259,7 @@ def register( framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 8c209d1280..d48f3b9cc5 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -356,6 +356,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -403,6 +404,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -448,6 +451,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, + skip_model_validation=skip_model_validation, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index ce3181688f..09b9ee0c68 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -173,6 +173,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -220,6 +221,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -258,6 +261,7 @@ def register( framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 6c4496a2af..3cc677d2dc 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3648,6 +3648,7 @@ def create_model_package_from_containers( domain=None, sample_payload_url=None, task=None, + skip_model_validation="None", ): """Get request dictionary for CreateModelPackage API. @@ -3682,6 +3683,8 @@ def create_model_package_from_containers( task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + skip_model_validation (str): Indicates if you want to skip model validation. + Values can be "All" or "None" (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -3737,6 +3740,7 @@ def create_model_package_from_containers( domain=domain, sample_payload_url=sample_payload_url, task=task, + skip_model_validation=skip_model_validation, ) def submit(request): @@ -5764,6 +5768,7 @@ def get_model_package_args( domain=None, sample_payload_url=None, task=None, + skip_model_validation=None, ): """Get arguments for create_model_package method. @@ -5800,6 +5805,8 @@ def get_model_package_args( task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + skip_model_validation (str): Indicates if you want to skip model validation. + Values can be "All" or "None" (default: None). Returns: dict: A dictionary of method argument names and values. @@ -5848,6 +5855,8 @@ def get_model_package_args( model_package_args["sample_payload_url"] = sample_payload_url if task is not None: model_package_args["task"] = task + if skip_model_validation is not None: + model_package_args["skip_model_validation"] = skip_model_validation return model_package_args @@ -5871,6 +5880,7 @@ def get_create_model_package_request( domain=None, sample_payload_url=None, task=None, + skip_model_validation="None", ): """Get request dictionary for CreateModelPackage API. @@ -5905,6 +5915,8 @@ def get_create_model_package_request( task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + skip_model_validation (str): Indicates if you want to skip model validation. + Values can be "All" or "None" (default: None). """ if all([model_package_name, model_package_group_name]): @@ -5974,6 +5986,7 @@ def get_create_model_package_request( request_dict["InferenceSpecification"] = inference_specification request_dict["CertifyForMarketplace"] = marketplace_cert request_dict["ModelApprovalStatus"] = approval_status + request_dict["SkipModelValidation"] = skip_model_validation return request_dict diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 46425c0660..9b1ea73090 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -166,6 +166,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -213,6 +214,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -251,6 +254,7 @@ def register( framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 705b7667dd..97b0e51b01 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -228,6 +228,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -275,6 +276,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -313,6 +316,7 @@ def register( framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index b7372237e3..8b30823d06 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -322,6 +322,7 @@ def __init__( domain=None, sample_payload_url=None, task=None, + skip_model_validation=None, **kwargs, ): """Constructor of a register model step. @@ -371,6 +372,8 @@ def __init__( task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + skip_model_validation (str): Indicates if you want to skip model validation. + Values can be "All" or "None" (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -407,6 +410,7 @@ def __init__( self.tags = tags self.kwargs = kwargs self.container_def_list = container_def_list + self.skip_model_validation = skip_model_validation self._properties = Properties(step_name=name, shape_name="DescribeModelPackageOutput") @@ -481,6 +485,7 @@ def arguments(self) -> RequestType: domain=self.domain, sample_payload_url=self.sample_payload_url, task=self.task, + skip_model_validation=self.skip_model_validation, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 8b6da2ae93..405f910639 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -90,6 +90,7 @@ def __init__( framework_version=None, nearest_model_name=None, data_input_configuration=None, + skip_model_validation=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -145,6 +146,8 @@ def __init__( nearest_model_name (str): Name of a pre-trained machine learning benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str): Input object for the model (default: None). + skip_model_validation (str): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). **kwargs: additional arguments to `create_model`. """ @@ -281,6 +284,7 @@ def __init__( domain=domain, sample_payload_url=sample_payload_url, task=task, + skip_model_validation=skip_model_validation, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index e2be025b62..a40151c925 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -154,6 +154,7 @@ def register( framework_version: Optional[Union[str, PipelineVariable]] = None, nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -201,6 +202,8 @@ def register( benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str or PipelineVariable): Input object for the model (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -239,6 +242,7 @@ def register( framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) def prepare_container_def( diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 98ee4bfde6..75e086bcfe 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -600,6 +600,7 @@ def test_model_registration_with_drift_check_baselines( framework_version = "2.9" nearest_model_name = "resnet50" data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) @@ -633,6 +634,7 @@ def test_model_registration_with_drift_check_baselines( framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, ) pipeline = Pipeline( @@ -703,6 +705,7 @@ def test_model_registration_with_drift_check_baselines( assert response["Domain"] == domain assert response["Task"] == task assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation break finally: try: diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index b6c17033ed..d05bb5c20f 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -218,6 +218,7 @@ def test_pipeline_session_context_for_model_step_without_instance_types( }, "CertifyForMarketplace": False, "ModelApprovalStatus": "PendingManualApproval", + "SkipModelValidation": "None", "SamplePayloadUrl": "s3://test-bucket/model", "Task": "IMAGE_CLASSIFICATION", } @@ -280,6 +281,7 @@ def test_pipeline_session_context_for_model_step_with_one_instance_types( }, "CertifyForMarketplace": False, "ModelApprovalStatus": "PendingManualApproval", + "SkipModelValidation": "None", "SamplePayloadUrl": "s3://test-bucket/model", "Task": "IMAGE_CLASSIFICATION", } diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index ddbd062f17..847a2d3656 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -498,6 +498,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): "SupportedTransformInstanceTypes": ["transform_instance"], }, "ModelApprovalStatus": "Approved", + "SkipModelValidation": "None", "ModelMetrics": { "Bias": {}, "Explainability": {}, @@ -568,6 +569,7 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): "SupportedTransformInstanceTypes": ["transform_instance"], }, "ModelApprovalStatus": "Approved", + "SkipModelValidation": "None", "ModelMetrics": { "Bias": {}, "Explainability": {}, @@ -657,6 +659,7 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "SupportedTransformInstanceTypes": ["transform_instance"], }, "ModelApprovalStatus": "Approved", + "SkipModelValidation": "None", "ModelMetrics": { "Bias": {}, "Explainability": {}, @@ -794,6 +797,7 @@ def test_register_model_with_model_repack_with_estimator( "SupportedTransformInstanceTypes": ["transform_instance"], }, "ModelApprovalStatus": "Approved", + "SkipModelValidation": "None", "ModelMetrics": { "Bias": {}, "Explainability": {}, @@ -919,6 +923,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift "SupportedTransformInstanceTypes": ["transform_instance"], }, "ModelApprovalStatus": "Approved", + "SkipModelValidation": "None", "ModelMetrics": { "Bias": {}, "Explainability": {}, @@ -1052,6 +1057,7 @@ def test_register_model_with_model_repack_with_pipeline_model( "SupportedTransformInstanceTypes": ["transform_instance"], }, "ModelApprovalStatus": "Approved", + "SkipModelValidation": "None", "ModelMetrics": { "Bias": {}, "Explainability": {}, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index b07f90a55b..cf35c8e0fe 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4621,6 +4621,7 @@ def test_create_model_package_from_containers_without_model_package_group_name( def test_create_model_package_with_sagemaker_config_injection(sagemaker_session): sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MODEL_PACKAGE + skip_model_validation = "All" model_package_name = "sagemaker-model-package" containers = [{"Image": "dummy-container"}] content_types = ["application/json"] @@ -4680,6 +4681,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) sample_payload_url=sample_payload_url, task=task, validation_specification=validation_specification, + skip_model_validation=skip_model_validation, ) expected_kms_key_id = SAGEMAKER_CONFIG_MODEL_PACKAGE["SageMaker"]["ModelPackage"][ "ValidationSpecification" @@ -4717,6 +4719,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) "SamplePayloadUrl": sample_payload_url, "Task": task, "ValidationSpecification": validation_specification, + "SkipModelValidation": skip_model_validation, } ) expected_args["ValidationSpecification"]["ValidationRole"] = expected_role_arn @@ -4764,6 +4767,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "GeneratedBy": "sagemaker-python-sdk", "ProjectId": "unit-test", } + skip_model_validation = "All" marketplace_cert = (True,) approval_status = ("Approved",) description = "description" @@ -4788,6 +4792,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): domain=domain, sample_payload_url=sample_payload_url, task=task, + skip_model_validation=skip_model_validation, ) expected_args = { "ModelPackageName": model_package_name, @@ -4808,6 +4813,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "Domain": domain, "SamplePayloadUrl": sample_payload_url, "Task": task, + "SkipModelValidation": skip_model_validation, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) @@ -4838,6 +4844,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s "GeneratedBy": "sagemaker-python-sdk", "ProjectId": "unit-test", } + skip_model_validation = "All" marketplace_cert = (True,) approval_status = ("Approved",) description = "description" @@ -4854,6 +4861,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s description=description, drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, + skip_model_validation=skip_model_validation, ) expected_args = { "ModelPackageGroupName": model_package_group_name, @@ -4869,6 +4877,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s "ModelApprovalStatus": approval_status, "DriftCheckBaselines": drift_check_baselines, "CustomerMetadataProperties": customer_metadata_properties, + "SkipModelValidation": skip_model_validation, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) @@ -4902,6 +4911,7 @@ def test_create_model_package_from_containers_with_one_instance_types( "GeneratedBy": "sagemaker-python-sdk", "ProjectId": "unit-test", } + skip_model_validation = "All" marketplace_cert = (True,) approval_status = ("Approved",) description = "description" @@ -4919,6 +4929,7 @@ def test_create_model_package_from_containers_with_one_instance_types( description=description, drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, + skip_model_validation=skip_model_validation, ) expected_args = { "ModelPackageGroupName": model_package_group_name, @@ -4935,6 +4946,7 @@ def test_create_model_package_from_containers_with_one_instance_types( "ModelApprovalStatus": approval_status, "DriftCheckBaselines": drift_check_baselines, "CustomerMetadataProperties": customer_metadata_properties, + "SkipModelValidation": skip_model_validation, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)