Skip to content

Commit 3d8ffb8

Browse files
Feature: Added support of ModelLifeCycle construct in Model Registry (#4919)
* feature: Added support of ModelLifeCycle construct in Model Registry * Fixed the docstyle error --------- Co-authored-by: parknate@ <[email protected]>
1 parent 292a00d commit 3d8ffb8

File tree

21 files changed

+417
-1
lines changed

21 files changed

+417
-1
lines changed

src/sagemaker/chainer/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sagemaker.utils import to_string
3838
from sagemaker.workflow import is_pipeline_variable
3939
from sagemaker.workflow.entities import PipelineVariable
40+
from sagemaker.model_life_cycle import ModelLifeCycle
4041

4142
logger = logging.getLogger("sagemaker")
4243

@@ -180,6 +181,7 @@ def register(
180181
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
181182
source_uri: Optional[Union[str, PipelineVariable]] = None,
182183
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
184+
model_life_cycle: Optional[ModelLifeCycle] = None,
183185
):
184186
"""Creates a model package for creating SageMaker models or listing on Marketplace.
185187
@@ -233,6 +235,7 @@ def register(
233235
(default: None).
234236
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
235237
quantitative information about a model (default: None).
238+
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
236239
237240
Returns:
238241
str: A string of SageMaker Model Package ARN.
@@ -274,6 +277,7 @@ def register(
274277
skip_model_validation=skip_model_validation,
275278
source_uri=source_uri,
276279
model_card=model_card,
280+
model_life_cycle=model_life_cycle,
277281
)
278282

279283
def prepare_container_def(

src/sagemaker/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,7 @@ def register(
17611761
data_input_configuration=None,
17621762
skip_model_validation=None,
17631763
source_uri=None,
1764+
model_life_cycle=None,
17641765
model_card=None,
17651766
**kwargs,
17661767
):
@@ -1812,6 +1813,7 @@ def register(
18121813
source_uri (str): The URI of the source for the model package (default: None).
18131814
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
18141815
quantitative information about a model (default: None).
1816+
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
18151817
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
18161818
``create_model()`` to accept ``**kwargs`` to customize model creation during
18171819
deploy. For more, see the implementation docs.
@@ -1867,6 +1869,7 @@ def register(
18671869
skip_model_validation=skip_model_validation,
18681870
source_uri=source_uri,
18691871
model_card=model_card,
1872+
model_life_cycle=model_life_cycle,
18701873
)
18711874

18721875
@property

src/sagemaker/huggingface/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from sagemaker.utils import to_string, format_tags
3737
from sagemaker.workflow import is_pipeline_variable
3838
from sagemaker.workflow.entities import PipelineVariable
39+
from sagemaker.model_life_cycle import ModelLifeCycle
3940

4041
logger = logging.getLogger("sagemaker")
4142

@@ -362,6 +363,7 @@ def register(
362363
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
363364
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
364365
source_uri: Optional[Union[str, PipelineVariable]] = None,
366+
model_life_cycle: Optional[ModelLifeCycle] = None,
365367
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
366368
):
367369
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -417,6 +419,7 @@ def register(
417419
(default: None).
418420
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
419421
quantitative information about a model (default: None).
422+
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
420423
421424
Returns:
422425
A `sagemaker.model.ModelPackage` instance.
@@ -465,6 +468,7 @@ def register(
465468
data_input_configuration=data_input_configuration,
466469
skip_model_validation=skip_model_validation,
467470
source_uri=source_uri,
471+
model_life_cycle=model_life_cycle,
468472
model_card=model_card,
469473
)
470474

src/sagemaker/jumpstart/factory/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
7676
from sagemaker import resource_requirements
7777
from sagemaker.enums import EndpointType
78+
from sagemaker.model_life_cycle import ModelLifeCycle
7879

7980

8081
def get_default_predictor(
@@ -756,6 +757,7 @@ def get_register_kwargs(
756757
data_input_configuration: Optional[str] = None,
757758
skip_model_validation: Optional[str] = None,
758759
source_uri: Optional[str] = None,
760+
model_life_cycle: Optional[ModelLifeCycle] = None,
759761
config_name: Optional[str] = None,
760762
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
761763
accept_eula: Optional[bool] = None,
@@ -794,6 +796,7 @@ def get_register_kwargs(
794796
data_input_configuration=data_input_configuration,
795797
skip_model_validation=skip_model_validation,
796798
source_uri=source_uri,
799+
model_life_cycle=model_life_cycle,
797800
model_card=model_card,
798801
accept_eula=accept_eula,
799802
)

src/sagemaker/jumpstart/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
from sagemaker.workflow.entities import PipelineVariable
7171
from sagemaker.model_metrics import ModelMetrics
7272
from sagemaker.metadata_properties import MetadataProperties
73+
from sagemaker.model_life_cycle import ModelLifeCycle
7374
from sagemaker.drift_check_baselines import DriftCheckBaselines
7475
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
7576

@@ -863,6 +864,7 @@ def register(
863864
source_uri: Optional[Union[str, PipelineVariable]] = None,
864865
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
865866
accept_eula: Optional[bool] = None,
867+
model_life_cycle: Optional[ModelLifeCycle] = None,
866868
):
867869
"""Creates a model package for creating SageMaker models or listing on Marketplace.
868870
@@ -917,6 +919,7 @@ def register(
917919
The `accept_eula` value must be explicitly defined as `True` in order to
918920
accept the end-user license agreement (EULA) that some
919921
models require. (Default: None).
922+
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
920923
Returns:
921924
A `sagemaker.model.ModelPackage` instance.
922925
"""
@@ -960,6 +963,7 @@ def register(
960963
config_name=self.config_name,
961964
model_card=model_card,
962965
accept_eula=accept_eula,
966+
model_life_cycle=model_life_cycle,
963967
)
964968

965969
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
camel_to_snake,
4343
walk_and_apply_json,
4444
)
45+
from sagemaker.model_life_cycle import ModelLifeCycle
4546

4647

4748
class JumpStartDataHolderType:
@@ -2779,6 +2780,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
27792780
"data_input_configuration",
27802781
"skip_model_validation",
27812782
"source_uri",
2783+
"model_life_cycle",
27822784
"config_name",
27832785
"model_card",
27842786
"accept_eula",
@@ -2828,6 +2830,7 @@ def __init__(
28282830
data_input_configuration: Optional[str] = None,
28292831
skip_model_validation: Optional[str] = None,
28302832
source_uri: Optional[str] = None,
2833+
model_life_cycle: Optional[ModelLifeCycle] = None,
28312834
config_name: Optional[str] = None,
28322835
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
28332836
accept_eula: Optional[bool] = None,

src/sagemaker/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
get_add_model_package_inference_args,
9191
get_update_model_package_inference_args,
9292
)
93+
from sagemaker.model_life_cycle import ModelLifeCycle
9394

9495
# Setting LOGGER for backward compatibility, in case users import it...
9596
logger = LOGGER = logging.getLogger("sagemaker")
@@ -473,6 +474,7 @@ def register(
473474
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
474475
source_uri: Optional[Union[str, PipelineVariable]] = None,
475476
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
477+
model_life_cycle: Optional[ModelLifeCycle] = None,
476478
accept_eula: Optional[bool] = None,
477479
model_type: Optional[JumpStartModelType] = None,
478480
):
@@ -528,6 +530,7 @@ def register(
528530
(default: None).
529531
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
530532
quantitative information about a model (default: None).
533+
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
531534
532535
Returns:
533536
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
@@ -597,6 +600,7 @@ def register(
597600
skip_model_validation=skip_model_validation,
598601
source_uri=source_uri,
599602
model_card=model_card,
603+
model_life_cycle=model_life_cycle,
600604
)
601605
model_package = self.sagemaker_session.create_model_package_from_containers(
602606
**model_pkg_args
@@ -2385,6 +2389,23 @@ def update_source_uri(
23852389
sagemaker_session = self.sagemaker_session or sagemaker.Session()
23862390
sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args)
23872391

2392+
def update_model_life_cycle(
2393+
self,
2394+
model_life_cycle: ModelLifeCycle,
2395+
):
2396+
"""Modellifecycle to be set for the model package
2397+
2398+
Args:
2399+
model_life_cycle (ModelLifeCycle): The current state of model package in its life cycle
2400+
2401+
"""
2402+
update_model_life_cycle_args = {
2403+
"ModelPackageArn": self.model_package_arn,
2404+
"ModelLifeCycle": model_life_cycle,
2405+
}
2406+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
2407+
sagemaker_session.sagemaker_client.update_model_package(**update_model_life_cycle_args)
2408+
23882409
def remove_customer_metadata_properties(
23892410
self, customer_metadata_properties_to_remove: List[str]
23902411
):

src/sagemaker/model_life_cycle.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This file contains code related to model life cycle."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
20+
21+
class ModelLifeCycle(object):
22+
"""Accepts ModelLifeCycle parameters for conversion to request dict."""
23+
24+
def __init__(
25+
self,
26+
stage: Optional[Union[str, PipelineVariable]] = None,
27+
stage_status: Optional[Union[str, PipelineVariable]] = None,
28+
stage_description: Optional[Union[str, PipelineVariable]] = None,
29+
):
30+
"""Initialize a ``ModelLifeCycle`` instance and turn parameters into dict.
31+
32+
# TODO: flesh out docstrings
33+
Args:
34+
stage (str or PipelineVariable):
35+
stage_status (str or PipelineVariable):
36+
stage_description (str or PipelineVariable):
37+
"""
38+
self.stage = stage
39+
self.stage_status = stage_status
40+
self.stage_description = stage_description
41+
42+
def _to_request_dict(self):
43+
"""Generates a request dictionary using the parameters provided to the class."""
44+
model_life_cycle_request = dict()
45+
if self.stage:
46+
model_life_cycle_request["Stage"] = self.stage
47+
if self.stage_status:
48+
model_life_cycle_request["StageStatus"] = self.stage_status
49+
if self.stage_description:
50+
model_life_cycle_request["StageDescription"] = self.stage_description
51+
return model_life_cycle_request

src/sagemaker/mxnet/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from sagemaker.utils import to_string
4040
from sagemaker.workflow import is_pipeline_variable
4141
from sagemaker.workflow.entities import PipelineVariable
42+
from sagemaker.model_life_cycle import ModelLifeCycle
4243

4344
logger = logging.getLogger("sagemaker")
4445

@@ -182,6 +183,7 @@ def register(
182183
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
183184
source_uri: Optional[Union[str, PipelineVariable]] = None,
184185
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
186+
model_life_cycle: Optional[ModelLifeCycle] = None,
185187
):
186188
"""Creates a model package for creating SageMaker models or listing on Marketplace.
187189
@@ -235,6 +237,7 @@ def register(
235237
(default: None).
236238
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
237239
quantitative information about a model (default: None).
240+
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
238241
239242
Returns:
240243
A `sagemaker.model.ModelPackage` instance.
@@ -276,6 +279,7 @@ def register(
276279
skip_model_validation=skip_model_validation,
277280
source_uri=source_uri,
278281
model_card=model_card,
282+
model_life_cycle=model_life_cycle,
279283
)
280284

281285
def prepare_container_def(

src/sagemaker/pytorch/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from sagemaker.utils import to_string
4040
from sagemaker.workflow import is_pipeline_variable
4141
from sagemaker.workflow.entities import PipelineVariable
42+
from sagemaker.model_life_cycle import ModelLifeCycle
4243

4344
logger = logging.getLogger("sagemaker")
4445

@@ -184,6 +185,7 @@ def register(
184185
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
185186
source_uri: Optional[Union[str, PipelineVariable]] = None,
186187
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
188+
model_life_cycle: Optional[ModelLifeCycle] = None,
187189
):
188190
"""Creates a model package for creating SageMaker models or listing on Marketplace.
189191
@@ -237,6 +239,7 @@ def register(
237239
(default: None).
238240
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
239241
quantitative information about a model (default: None).
242+
model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
240243
241244
Returns:
242245
A `sagemaker.model.ModelPackage` instance.
@@ -278,6 +281,7 @@ def register(
278281
skip_model_validation=skip_model_validation,
279282
source_uri=source_uri,
280283
model_card=model_card,
284+
model_life_cycle=model_life_cycle,
281285
)
282286

283287
def prepare_container_def(

0 commit comments

Comments
 (0)