Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -180,6 +181,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -233,6 +235,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
str: A string of SageMaker Model Package ARN.
Expand Down Expand Up @@ -274,6 +277,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

def prepare_container_def(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,7 @@ def register(
data_input_configuration=None,
skip_model_validation=None,
source_uri=None,
model_life_cycle=None,
model_card=None,
**kwargs,
):
Expand Down Expand Up @@ -1812,6 +1813,7 @@ def register(
source_uri (str): The URI of the source for the model package (default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during
deploy. For more, see the implementation docs.
Expand Down Expand Up @@ -1867,6 +1869,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

@property
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sagemaker.utils import to_string, format_tags
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -362,6 +363,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -417,6 +419,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -465,6 +468,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_life_cycle=model_life_cycle,
model_card=model_card,
)

Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker import resource_requirements
from sagemaker.enums import EndpointType
from sagemaker.model_life_cycle import ModelLifeCycle


def get_default_predictor(
Expand Down Expand Up @@ -756,6 +757,7 @@ def get_register_kwargs(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
config_name: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
Expand Down Expand Up @@ -794,6 +796,7 @@ def get_register_kwargs(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_life_cycle=model_life_cycle,
model_card=model_card,
accept_eula=accept_eula,
)
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model_life_cycle import ModelLifeCycle
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

Expand Down Expand Up @@ -863,6 +864,7 @@ def register(
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -917,6 +919,7 @@ def register(
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
Expand Down Expand Up @@ -960,6 +963,7 @@ def register(
config_name=self.config_name,
model_card=model_card,
accept_eula=accept_eula,
model_life_cycle=model_life_cycle,
)

model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
camel_to_snake,
walk_and_apply_json,
)
from sagemaker.model_life_cycle import ModelLifeCycle


class JumpStartDataHolderType:
Expand Down Expand Up @@ -2779,6 +2780,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"data_input_configuration",
"skip_model_validation",
"source_uri",
"model_life_cycle",
"config_name",
"model_card",
"accept_eula",
Expand Down Expand Up @@ -2828,6 +2830,7 @@ def __init__(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
config_name: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
Expand Down
21 changes: 21 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
get_add_model_package_inference_args,
get_update_model_package_inference_args,
)
from sagemaker.model_life_cycle import ModelLifeCycle

# Setting LOGGER for backward compatibility, in case users import it...
logger = LOGGER = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -473,6 +474,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
accept_eula: Optional[bool] = None,
model_type: Optional[JumpStartModelType] = None,
):
Expand Down Expand Up @@ -528,6 +530,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
Expand Down Expand Up @@ -597,6 +600,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)
model_package = self.sagemaker_session.create_model_package_from_containers(
**model_pkg_args
Expand Down Expand Up @@ -2385,6 +2389,23 @@ def update_source_uri(
sagemaker_session = self.sagemaker_session or sagemaker.Session()
sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args)

def update_model_life_cycle(
self,
model_life_cycle: ModelLifeCycle,
):
"""ModelLifeCycle to be set for the model package

Args:
model_life_cycle (ModelLifeCycle): The current state of model package in its life cycle

"""
update_model_life_cycle_args = {
"ModelPackageArn": self.model_package_arn,
"ModelLifeCycle": model_life_cycle,
}
sagemaker_session = self.sagemaker_session or sagemaker.Session()
sagemaker_session.sagemaker_client.update_model_package(**update_model_life_cycle_args)

def remove_customer_metadata_properties(
self, customer_metadata_properties_to_remove: List[str]
):
Expand Down
51 changes: 51 additions & 0 deletions src/sagemaker/model_life_cycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This file contains code related to model life cycle."""
from __future__ import absolute_import

from typing import Optional, Union

from sagemaker.workflow.entities import PipelineVariable


class ModelLifeCycle(object):
"""Accepts ModelLifeCycle parameters for conversion to request dict."""

def __init__(
self,
stage: Optional[Union[str, PipelineVariable]] = None,
stage_status: Optional[Union[str, PipelineVariable]] = None,
stage_description: Optional[Union[str, PipelineVariable]] = None,
):
"""Initialize a ``ModelLifeCycle`` instance and turn parameters into dict.

# TODO: flesh out docstrings
Args:
stage (str or PipelineVariable):
stage_status (str or PipelineVariable):
stage_description (str or PipelineVariable):
"""
self.stage = stage
self.stage_status = stage_status
self.stage_description = stage_description

def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
model_life_cycle_request = dict()
if self.stage:
model_life_cycle_request["Stage"] = self.stage
if self.stage_status:
model_life_cycle_request["StageStatus"] = self.stage_status
if self.stage_description:
model_life_cycle_request["StageDescription"] = self.stage_description
return model_life_cycle_request
4 changes: 4 additions & 0 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -182,6 +183,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -235,6 +237,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -276,6 +279,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

def prepare_container_def(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_life_cycle import ModelLifeCycle

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -184,6 +185,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
model_life_cycle: Optional[ModelLifeCycle] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -237,6 +239,7 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
Expand Down Expand Up @@ -278,6 +281,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
model_life_cycle=model_life_cycle,
)

def prepare_container_def(
Expand Down
Loading
Loading