Skip to content

Feature: Added support of ModelLifeCycle construct in Model Registry #4919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -180,6 +181,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -233,6 +235,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
str: A string of SageMaker Model Package ARN.
Expand Down Expand Up @@ -274,6 +277,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

def prepare_container_def(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,7 @@ def register(
data_input_configuration=None,
skip_model_validation=None,
source_uri=None,
model_life_cycle=None,
model_card=None,
**kwargs,
):
Expand Down Expand Up @@ -1812,6 +1813,7 @@ def register(
source_uri (str): The URI of the source for the model package (default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during
deploy. For more, see the implementation docs.
Expand Down Expand Up @@ -1867,6 +1869,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

@property
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sagemaker.utils import to_string, format_tags
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -362,6 +363,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -417,6 +419,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -465,6 +468,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_life_cycle=model_life_cycle,
model_card=model_card,
)

Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker import resource_requirements
from sagemaker.enums import EndpointType
from sagemaker.model_life_cycle import ModelLifeCycle


def get_default_predictor(
Expand Down Expand Up @@ -756,6 +757,7 @@ def get_register_kwargs(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
config_name: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
Expand Down Expand Up @@ -794,6 +796,7 @@ def get_register_kwargs(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_life_cycle=model_life_cycle,
model_card=model_card,
accept_eula=accept_eula,
)
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model_life_cycle import ModelLifeCycle
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

Expand Down Expand Up @@ -863,6 +864,7 @@ def register(
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -917,6 +919,7 @@ def register(
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
Expand Down Expand Up @@ -960,6 +963,7 @@ def register(
config_name=self.config_name,
model_card=model_card,
accept_eula=accept_eula,
model_life_cycle=model_life_cycle,
)

model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
camel_to_snake,
walk_and_apply_json,
)
from sagemaker.model_life_cycle import ModelLifeCycle


class JumpStartDataHolderType:
Expand Down Expand Up @@ -2779,6 +2780,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"data_input_configuration",
"skip_model_validation",
"source_uri",
"model_life_cycle",
"config_name",
"model_card",
"accept_eula",
Expand Down Expand Up @@ -2828,6 +2830,7 @@ def __init__(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
config_name: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
Expand Down
21 changes: 21 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
get_add_model_package_inference_args,
get_update_model_package_inference_args,
)
from sagemaker.model_life_cycle import ModelLifeCycle

# Setting LOGGER for backward compatibility, in case users import it...
logger = LOGGER = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -473,6 +474,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
accept_eula: Optional[bool] = None,
model_type: Optional[JumpStartModelType] = None,
):
Expand Down Expand Up @@ -528,6 +530,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
Expand Down Expand Up @@ -597,6 +600,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
Expand Down Expand Up @@ -2385,6 +2389,23 @@ def update_source_uri(
sagemaker_session = self.sagemaker_session or sagemaker.Session()
sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args)

def update_model_life_cycle(
self,
model_life_cycle: ModelLifeCycle,
):
"""Modellifecycle to be set for the model package

Args:
model_life_cycle (ModelLifeCycle): The current state of model package in its life cycle

"""
update_model_life_cycle_args = {
"ModelPackageArn": self.model_package_arn,
"ModelLifeCycle": model_life_cycle,
}
sagemaker_session = self.sagemaker_session or sagemaker.Session()
sagemaker_session.sagemaker_client.update_model_package(**update_model_life_cycle_args)

def remove_customer_metadata_properties(
self, customer_metadata_properties_to_remove: List[str]
):
Expand Down
51 changes: 51 additions & 0 deletions src/sagemaker/model_life_cycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This file contains code related to model life cycle."""
from __future__ import absolute_import

from typing import Optional, Union

from sagemaker.workflow.entities import PipelineVariable


class ModelLifeCycle(object):
"""Accepts ModelLifeCycle parameters for conversion to request dict."""

def __init__(
self,
stage: Optional[Union[str, PipelineVariable]] = None,
stage_status: Optional[Union[str, PipelineVariable]] = None,
stage_description: Optional[Union[str, PipelineVariable]] = None,
):
"""Initialize a ``ModelLifeCycle`` instance and turn parameters into dict.

# TODO: flesh out docstrings
Args:
stage (str or PipelineVariable):
stage_status (str or PipelineVariable):
stage_description (str or PipelineVariable):
"""
self.stage = stage
self.stage_status = stage_status
self.stage_description = stage_description

def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
model_life_cycle_request = dict()
if self.stage:
model_life_cycle_request["Stage"] = self.stage
if self.stage_status:
model_life_cycle_request["StageStatus"] = self.stage_status
if self.stage_description:
model_life_cycle_request["StageDescription"] = self.stage_description
return model_life_cycle_request
4 changes: 4 additions & 0 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -182,6 +183,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -235,6 +237,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -276,6 +279,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

def prepare_container_def(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -184,6 +185,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -237,6 +239,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -278,6 +281,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

def prepare_container_def(
Expand Down
Loading
Loading