diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index a99350f477..806009b0f6 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -37,6 +37,7 @@ from sagemaker.utils import to_string from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -180,6 +181,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -233,6 +235,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -274,6 +277,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 2fbcf4373b..a58d701337 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1761,6 +1761,7 @@ def register( data_input_configuration=None, skip_model_validation=None, source_uri=None, + model_life_cycle=None, model_card=None, **kwargs, ): @@ -1812,6 +1813,7 @@ def register( source_uri (str): The URI of the source for the model package (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1867,6 +1869,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 46319474fa..ea99be2fc0 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -36,6 +36,7 @@ from sagemaker.utils import to_string, format_tags from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -362,6 +363,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -417,6 +419,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -465,6 +468,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_life_cycle=model_life_cycle, model_card=model_card, ) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index ccafed844d..a5e9e1b6a4 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -75,6 +75,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker import resource_requirements from sagemaker.enums import EndpointType +from sagemaker.model_life_cycle import ModelLifeCycle def get_default_predictor( @@ -756,6 +757,7 @@ def get_register_kwargs( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, config_name: Optional[str] = None, model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, accept_eula: Optional[bool] = None, @@ -794,6 +796,7 @@ def get_register_kwargs( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_life_cycle=model_life_cycle, model_card=model_card, accept_eula=accept_eula, ) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index e6952b2154..486079718b 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -70,6 +70,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model_life_cycle import ModelLifeCycle from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements @@ -863,6 +864,7 @@ def register( source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, accept_eula: Optional[bool] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -917,6 +919,7 @@ def register( The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. """ @@ -960,6 +963,7 @@ def register( config_name=self.config_name, model_card=model_card, accept_eula=accept_eula, + model_life_cycle=model_life_cycle, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 7e075e6b8a..e77c407372 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -42,6 +42,7 @@ camel_to_snake, walk_and_apply_json, ) +from sagemaker.model_life_cycle import ModelLifeCycle class JumpStartDataHolderType: @@ -2779,6 +2780,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "data_input_configuration", "skip_model_validation", "source_uri", + "model_life_cycle", "config_name", "model_card", "accept_eula", @@ -2828,6 +2830,7 @@ def __init__( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, config_name: Optional[str] = None, model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, accept_eula: Optional[bool] = None, diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 040c6dd71f..340d35b250 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -90,6 +90,7 @@ get_add_model_package_inference_args, get_update_model_package_inference_args, ) +from sagemaker.model_life_cycle import ModelLifeCycle # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") @@ -473,6 +474,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, accept_eula: Optional[bool] = None, model_type: Optional[JumpStartModelType] = None, ): @@ -528,6 +530,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -597,6 +600,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args @@ -2385,6 +2389,23 @@ def update_source_uri( sagemaker_session = self.sagemaker_session or sagemaker.Session() sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args) + def update_model_life_cycle( + self, + model_life_cycle: ModelLifeCycle, + ): + """Modellifecycle to be set for the model package + + Args: + model_life_cycle (ModelLifeCycle): The current state of model package in its life cycle + + """ + update_model_life_cycle_args = { + "ModelPackageArn": self.model_package_arn, + "ModelLifeCycle": model_life_cycle, + } + sagemaker_session = self.sagemaker_session or sagemaker.Session() + sagemaker_session.sagemaker_client.update_model_package(**update_model_life_cycle_args) + def remove_customer_metadata_properties( self, customer_metadata_properties_to_remove: List[str] ): diff --git a/src/sagemaker/model_life_cycle.py b/src/sagemaker/model_life_cycle.py new file mode 100644 index 0000000000..59403e91c8 --- /dev/null +++ b/src/sagemaker/model_life_cycle.py @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This file contains code related to model life cycle.""" +from __future__ import absolute_import + +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + + +class ModelLifeCycle(object): + """Accepts ModelLifeCycle parameters for conversion to request dict.""" + + def __init__( + self, + stage: Optional[Union[str, PipelineVariable]] = None, + stage_status: Optional[Union[str, PipelineVariable]] = None, + stage_description: Optional[Union[str, PipelineVariable]] = None, + ): + """Initialize a ``ModelLifeCycle`` instance and turn parameters into dict. + + # TODO: flesh out docstrings + Args: + stage (str or PipelineVariable): + stage_status (str or PipelineVariable): + stage_description (str or PipelineVariable): + """ + self.stage = stage + self.stage_status = stage_status + self.stage_description = stage_description + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + model_life_cycle_request = dict() + if self.stage: + model_life_cycle_request["Stage"] = self.stage + if self.stage_status: + model_life_cycle_request["StageStatus"] = self.stage_status + if self.stage_description: + model_life_cycle_request["StageDescription"] = self.stage_description + return model_life_cycle_request diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index c220a022d6..0dcd71741d 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -39,6 +39,7 @@ from sagemaker.utils import to_string from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -182,6 +183,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -235,6 +237,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -276,6 +279,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 491402e789..329f9b83b5 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -39,6 +39,7 @@ from sagemaker.utils import to_string from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -184,6 +185,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -237,6 +239,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -278,6 +281,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index b10a809259..bbc2c81904 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4172,6 +4172,7 @@ def create_model_package_from_containers( skip_model_validation="None", source_uri=None, model_card=None, + model_life_cycle=None, ): """Get request dictionary for CreateModelPackage API. @@ -4211,6 +4212,7 @@ def create_model_package_from_containers( source_uri (str): The URI of the source for the model package (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -4269,6 +4271,7 @@ def create_model_package_from_containers( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) def submit(request): @@ -7196,6 +7199,7 @@ def get_model_package_args( skip_model_validation=None, source_uri=None, model_card=None, + model_life_cycle=None, ): """Get arguments for create_model_package method. @@ -7237,6 +7241,7 @@ def get_model_package_args( source_uri (str): The URI of the source for the model package (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: dict: A dictionary of method argument names and values. @@ -7293,6 +7298,8 @@ def get_model_package_args( model_package_args["skip_model_validation"] = skip_model_validation if source_uri is not None: model_package_args["source_uri"] = source_uri + if model_life_cycle is not None: + model_package_args["model_life_cycle"] = model_life_cycle if model_card is not None: original_req = model_card._create_request_args() if original_req.get("ModelCardName") is not None: @@ -7327,6 +7334,7 @@ def get_create_model_package_request( skip_model_validation="None", source_uri=None, model_card=None, + model_life_cycle=None, ): """Get request dictionary for CreateModelPackage API. @@ -7366,6 +7374,7 @@ def get_create_model_package_request( source_uri (str): The URI of the source for the model package (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). """ if all([model_package_name, model_package_group_name]): @@ -7465,7 +7474,8 @@ def get_create_model_package_request( request_dict["SkipModelValidation"] = skip_model_validation if model_card is not None: request_dict["ModelCard"] = model_card - + if model_life_cycle is not None: + request_dict["ModelLifeCycle"] = model_life_cycle return request_dict diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index a9252706c3..c3727b2fb5 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -33,6 +33,7 @@ from sagemaker.utils import to_string from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -177,6 +178,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -230,6 +232,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -271,6 +274,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index f2e82507e8..fe20994e20 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -32,6 +32,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.utils import format_tags +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger(__name__) @@ -239,6 +240,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -292,6 +294,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -333,6 +336,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index d2221eeb7c..36c393969a 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -330,6 +330,7 @@ def __init__( skip_model_validation=None, source_uri=None, model_card=None, + model_life_cycle=None, **kwargs, ): """Constructor of a register model step. @@ -384,6 +385,7 @@ def __init__( source_uri (str): The URI of the source for the model package (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -422,6 +424,7 @@ def __init__( self.skip_model_validation = skip_model_validation self.source_uri = source_uri self.model_card = model_card + self.model_life_cycle = model_life_cycle self._properties = Properties( step_name=name, step=self, shape_name="DescribeModelPackageOutput" @@ -498,6 +501,7 @@ def arguments(self) -> RequestType: skip_model_validation=self.skip_model_validation, source_uri=self.source_uri, model_card=self.model_card, + model_life_cycle=self.model_life_cycle, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 0682716a48..a1d939254c 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -98,6 +98,7 @@ def __init__( skip_model_validation=None, source_uri=None, model_card=None, + model_life_cycle=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -158,6 +159,7 @@ def __init__( source_uri (str): The URI of the source for the model package (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). **kwargs: additional arguments to `create_model`. """ super().__init__(name=name, depends_on=depends_on) @@ -297,6 +299,7 @@ def __init__( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index cf6bd2826c..ea532b4c39 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -34,6 +34,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.xgboost.defaults import XGBOOST_NAME from sagemaker.xgboost.utils import validate_py_version, validate_framework_version +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -165,6 +166,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -218,6 +220,7 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -259,6 +262,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index fd981b74e4..7f85c0066c 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -929,6 +929,175 @@ def test_model_registration_with_model_card_object( pass +def test_model_registration_with_model_life_cycle_object( + sagemaker_session_for_pipeline, + role, + pipeline_name, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = "ml.m5.xlarge" + + # upload model data to s3 + model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz") + model_base_uri = "s3://{}/{}/input/model/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("model"), + ) + model_uri = S3Uploader.upload( + model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline + ) + model_uri_param = ParameterString(name="model_uri", default_value=model_uri) + + # upload metrics to s3 + metrics_data = ( + '{"regression_metrics": {"mse": {"value": 4.925353410353891, ' + '"standard_deviation": 2.219186917819692}}}' + ) + metrics_base_uri = "s3://{}/{}/input/metrics/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("metrics"), + ) + metrics_uri = S3Uploader.upload_string_as_file_body( + body=metrics_data, + desired_s3_uri=metrics_base_uri, + sagemaker_session=sagemaker_session_for_pipeline, + ) + metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri) + + model_metrics = ModelMetrics( + bias=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + explainability=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_pre_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_post_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + ) + customer_metadata_properties = {"key1": "value1"} + domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + estimator = XGBoost( + entry_point="training.py", + source_dir=os.path.join(DATA_DIR, "sip"), + instance_type=instance_type, + instance_count=instance_count, + framework_version="0.90-2", + sagemaker_session=sagemaker_session_for_pipeline, + py_version="py3", + role=role, + ) + create_model_life_cycle = { + "Stage": "Development", + "StageStatus": "In-Progress", + "StageDescription": "Development In Progress", + } + + step_register = RegisterModel( + name="MyRegisterModelStep", + estimator=estimator, + model_data=model_uri_param, + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="testModelPackageGroup", + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + model_life_cycle=create_model_life_cycle, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + model_uri_param, + metrics_uri_param, + instance_count, + ], + steps=[step_register], + sagemaker_session=sagemaker_session_for_pipeline, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start( + parameters={"model_uri": model_uri, "metrics_uri": metrics_uri} + ) + response = execution.describe() + + assert response["PipelineArn"] == create_arn + + wait_pipeline_execution(execution=execution) + execution_steps = execution.list_steps() + + assert len(execution_steps) == 1 + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error( + f"Pipeline execution failed with error: {failure_reason}." " Retrying.." + ) + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" + + response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( + ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] + ) + + assert ( + response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] + == "application/json" + ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties + assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation + assert (response["ModelLifeCycle"]["Stage"]) == "Development" + assert (response["ModelLifeCycle"]["StageStatus"]) == "In-Progress" + assert (response["ModelLifeCycle"]["StageDescription"]) == "Development In Progress" + break + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_model_registration_with_model_card_json( sagemaker_session_for_pipeline, role, diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py index f59901ee61..bc8120bd07 100644 --- a/tests/integ/test_model_package.py +++ b/tests/integ/test_model_package.py @@ -29,6 +29,7 @@ from sagemaker import image_uris from sagemaker.session import get_execution_role from sagemaker.model import ModelPackage +from sagemaker.model_life_cycle import ModelLifeCycle _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") @@ -75,6 +76,62 @@ def test_update_approval_model_package(sagemaker_session): ) +def test_update_model_life_cycle_model_package(sagemaker_session): + + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + create_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="In-Progress", + stage_description="Development In Progress", + ) + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + model_life_cycle=create_model_life_cycle._to_request_dict(), + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + create_model_life_cycle_req = create_model_life_cycle._to_request_dict() + + assert desc_model_package["ModelLifeCycle"] == create_model_life_cycle_req + + update_model_life_cycle = ModelLifeCycle( + stage="Staging", + stage_status="In-Progress", + stage_description="Sending for Staging Verification", + ) + update_model_life_cycle_req = update_model_life_cycle._to_request_dict() + + model_package.update_model_life_cycle(update_model_life_cycle_req) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + assert desc_model_package["ModelLifeCycle"] == update_model_life_cycle_req + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + + def test_inference_specification_addition(sagemaker_session): model_group_name = unique_name_from_base("test-model-group") diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 062ffaf2ed..85649a8d24 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -21,6 +21,7 @@ from sagemaker.model import ModelPackage from sagemaker.model_card.model_card import ModelCard, ModelOverview from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum +from sagemaker.model_life_cycle import ModelLifeCycle MODEL_PACKAGE_VERSIONED_ARN = ( "arn:aws:sagemaker:us-west-2:001234567890:model-package/testmodelgroup/1" @@ -492,3 +493,40 @@ def test_update_model_card(sagemaker_session): sagemaker_session.sagemaker_client.update_model_package.assert_called_with( ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req_1 ) + + +def test_update_model_life_cycle(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + + update_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="Approved", + stage_description="Approving for Development", + ) + update_model_life_cycle_req = update_model_life_cycle._to_request_dict() + model_package.update_model_life_cycle(update_model_life_cycle_req) + + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelLifeCycle=update_model_life_cycle_req + ) + + update_model_life_cycle1 = ModelLifeCycle( + stage="Staging", + stage_status="In-Progress", + stage_description="Sending for Staging Verification", + ) + update_model_life_cycle_req1 = update_model_life_cycle1._to_request_dict() + model_package.update_model_life_cycle(update_model_life_cycle_req1) + + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelLifeCycle=update_model_life_cycle_req1 + ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 31afaa0e7e..0bc84d29d0 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -74,6 +74,7 @@ DEFAULT_S3_BUCKET_NAME, DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, ) +from sagemaker.model_life_cycle import ModelLifeCycle MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -4345,6 +4346,12 @@ def test_register_default_image(sagemaker_session): status=ModelCardStatusEnum.DRAFT, model_overview=model_overview, ) + update_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="In-Progress", + stage_description="Sending for Staging Verification", + ) + update_model_life_cycle_req = update_model_life_cycle._to_request_dict() estimator.register( content_types=content_types, @@ -4359,12 +4366,19 @@ def test_register_default_image(sagemaker_session): nearest_model_name=nearest_model_name, data_input_configuration=data_input_config, model_card=model_card, + model_life_cycle=update_model_life_cycle_req, ) sagemaker_session.create_model.assert_not_called() exp_model_card = { "ModelCardStatus": "Draft", "ModelCardContent": '{"model_overview": {"model_creator": "TestCreator", "model_artifact": []}}', } + exp_model_life_cycle = { + "Stage": "Development", + "StageStatus": "In-Progress", + "StageDescription": "Sending for Staging Verification", + } + expected_create_model_package_request = { "containers": [{"Image": estimator.image_uri, "ModelDataUrl": estimator.model_data}], "content_types": content_types, @@ -4375,6 +4389,7 @@ def test_register_default_image(sagemaker_session): "marketplace_cert": False, "sample_payload_url": sample_payload_url, "task": task, + "model_life_cycle": exp_model_life_cycle, "model_card": exp_model_card, } sagemaker_session.create_model_package_from_containers.assert_called_with( diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index c776dfe479..d2d2c3bcfb 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5360,6 +5360,11 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): }, }, } + model_life_cycle = { + "Stage": "Development", + "StageStatus": "In-Progress", + "StageDescription": "Sending for Staging Verification", + } sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5379,6 +5384,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): task=task, skip_model_validation=skip_model_validation, model_card=model_card, + model_life_cycle=model_life_cycle, ) expected_args = { "ModelPackageName": model_package_name, @@ -5401,6 +5407,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "Task": task, "SkipModelValidation": skip_model_validation, "ModelCard": model_card, + "ModelLifeCycle": model_life_cycle, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)