Skip to content

Commit e8c42ae

Browse files
keshav-chandakKeshav Chandak
andauthored
feat: Model Package support for updating approval (#4134)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent 16d1556 commit e8c42ae

File tree

15 files changed

+189
-34
lines changed

15 files changed

+189
-34
lines changed

src/sagemaker/chainer/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def __init__(
148148

149149
def register(
150150
self,
151-
content_types: List[Union[str, PipelineVariable]],
152-
response_types: List[Union[str, PipelineVariable]],
151+
content_types: List[Union[str, PipelineVariable]] = None,
152+
response_types: List[Union[str, PipelineVariable]] = None,
153153
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
154154
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
155155
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,8 +1665,8 @@ def deploy(
16651665

16661666
def register(
16671667
self,
1668-
content_types,
1669-
response_types,
1668+
content_types=None,
1669+
response_types=None,
16701670
inference_instances=None,
16711671
transform_instances=None,
16721672
image_uri=None,

src/sagemaker/huggingface/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ def deploy(
332332

333333
def register(
334334
self,
335-
content_types: List[Union[str, PipelineVariable]],
336-
response_types: List[Union[str, PipelineVariable]],
335+
content_types: List[Union[str, PipelineVariable]] = None,
336+
response_types: List[Union[str, PipelineVariable]] = None,
337337
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
338338
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
339339
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/model.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4444
load_sagemaker_config,
4545
)
46+
from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum
4647
from sagemaker.session import Session
4748
from sagemaker.model_metrics import ModelMetrics
4849
from sagemaker.deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374375
self.dependencies = updates["dependencies"]
375376
self.uploaded_code = None
376377
self.repacked_model_data = None
378+
self.content_types = None
379+
self.response_types = None
377380

378381
@runnable_by_pipeline
379382
def register(
380383
self,
381-
content_types: List[Union[str, PipelineVariable]],
382-
response_types: List[Union[str, PipelineVariable]],
384+
content_types: List[Union[str, PipelineVariable]] = None,
385+
response_types: List[Union[str, PipelineVariable]] = None,
383386
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
384387
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
385388
model_package_name: Optional[Union[str, PipelineVariable]] = None,
@@ -456,16 +459,33 @@ def register(
456459
in case the Model instance is built with
457460
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458461
"""
459-
if self.model_data is None:
460-
raise ValueError("SageMaker Model Package cannot be created without model data.")
461462
if isinstance(self.model_data, dict):
462463
raise ValueError(
463464
"SageMaker Model Package currently cannot be created with ModelDataSource."
464465
)
465466

467+
if content_types is not None:
468+
self.content_types = content_types
469+
470+
if response_types is not None:
471+
self.response_types = response_types
472+
473+
if self.content_types is None:
474+
raise ValueError("The supported MIME types for the input data is not set")
475+
476+
if self.response_types is None:
477+
raise ValueError("The supported MIME types for the output data is not set")
478+
466479
if image_uri is not None:
467480
self.image_uri = image_uri
468481

482+
if model_package_group_name is None and model_package_name is None:
483+
# If model package group and model package name is not set
484+
# then register to auto-generated model package group
485+
model_package_group_name = utils.base_name_from_image(
486+
self.image_uri, default_base_name=ModelPackage.__name__
487+
)
488+
469489
if model_package_group_name is not None:
470490
container_def = self.prepare_container_def()
471491
container_def = update_container_with_inference_params(
@@ -478,12 +498,14 @@ def register(
478498
else:
479499
container_def = {
480500
"Image": self.image_uri,
481-
"ModelDataUrl": self.model_data,
482501
}
483502

503+
if self.model_data is not None:
504+
container_def["ModelDataUrl"] = self.model_data
505+
484506
model_pkg_args = sagemaker.get_model_package_args(
485-
content_types,
486-
response_types,
507+
self.content_types,
508+
self.response_types,
487509
inference_instances=inference_instances,
488510
transform_instances=transform_instances,
489511
model_package_name=model_package_name,
@@ -511,6 +533,7 @@ def register(
511533
role=self.role,
512534
model_data=self.model_data,
513535
model_package_arn=model_package.get("ModelPackageArn"),
536+
sagemaker_session=self.sagemaker_session,
514537
)
515538

516539
@runnable_by_pipeline
@@ -1751,6 +1774,7 @@ def __init__(
17511774

17521775
# works for MODEL_PACKAGE_ARN with or without version info.
17531776
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1777+
MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"
17541778

17551779

17561780
class ModelPackage(Model):
@@ -1885,6 +1909,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
18851909
self._ensure_base_name_if_needed(model_package_name)
18861910
self._set_model_name_if_needed()
18871911

1912+
# Quering the approval status for the model package
1913+
# Approving the versioned model package in case it is not approved
1914+
model_package_desc = self.sagemaker_session.sagemaker_client.describe_model_package(
1915+
ModelPackageName=self.model_package_arn or model_package_name
1916+
)
1917+
if self.model_package_arn is None:
1918+
self.model_package_arn = model_package_desc["ModelPackageArn"]
1919+
if re.match(MODEL_PACKAGE_VERSIONED_ARN_PATTERN, self.model_package_arn):
1920+
approval_status = model_package_desc.get("ModelApprovalStatus", "")
1921+
if approval_status != ModelApprovalStatusEnum.APPROVED:
1922+
self.update_approval_status(approval_status=ModelApprovalStatusEnum.APPROVED)
1923+
18881924
self.sagemaker_session.create_model(
18891925
self.name,
18901926
self.role,
@@ -1898,3 +1934,29 @@ def _ensure_base_name_if_needed(self, base_name):
18981934
"""Set the base name if there is no model name provided."""
18991935
if self.name is None:
19001936
self._base_name = base_name
1937+
1938+
def update_approval_status(self, approval_status, approval_description=None):
1939+
"""Update the approval status for the model package
1940+
1941+
Args:
1942+
approval_status (str or PipelineVariable): Model Approval Status, values can be
1943+
"Approved", "Rejected", or "PendingManualApproval".
1944+
approval_description (str): Optional. Description for the approval status of the model
1945+
(default: None).
1946+
"""
1947+
1948+
# Models can lazy-init sagemaker_session until deploy() is called to support
1949+
# LocalMode so we must make sure we have an actual session
1950+
sagemaker_session = self.sagemaker_session or sagemaker.Session()
1951+
if self.model_package_arn is None:
1952+
raise ValueError("model_package_arn is required to update the status.")
1953+
1954+
update_approval_args = {
1955+
"ModelPackageArn": self.model_package_arn,
1956+
"ModelApprovalStatus": approval_status,
1957+
}
1958+
1959+
if approval_description is not None:
1960+
update_approval_args["ApprovalDescription"] = approval_description
1961+
1962+
sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)

src/sagemaker/mxnet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def __init__(
150150

151151
def register(
152152
self,
153-
content_types: List[Union[str, PipelineVariable]],
154-
response_types: List[Union[str, PipelineVariable]],
153+
content_types: List[Union[str, PipelineVariable]] = None,
154+
response_types: List[Union[str, PipelineVariable]] = None,
155155
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
156156
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
157157
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def _create_sagemaker_pipeline_model(self, instance_type):
335335
@runnable_by_pipeline
336336
def register(
337337
self,
338-
content_types: List[Union[str, PipelineVariable]],
339-
response_types: List[Union[str, PipelineVariable]],
338+
content_types: List[Union[str, PipelineVariable]] = None,
339+
response_types: List[Union[str, PipelineVariable]] = None,
340340
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
341341
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
342342
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def __init__(
152152

153153
def register(
154154
self,
155-
content_types: List[Union[str, PipelineVariable]],
156-
response_types: List[Union[str, PipelineVariable]],
155+
content_types: List[Union[str, PipelineVariable]] = None,
156+
response_types: List[Union[str, PipelineVariable]] = None,
157157
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
158158
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
159159
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/session.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5830,8 +5830,8 @@ def wait_for_inference_recommendations_job(
58305830

58315831

58325832
def get_model_package_args(
5833-
content_types,
5834-
response_types,
5833+
content_types=None,
5834+
response_types=None,
58355835
inference_instances=None,
58365836
transform_instances=None,
58375837
model_package_name=None,
@@ -5899,19 +5899,23 @@ def get_model_package_args(
58995899
else:
59005900
container = {
59015901
"Image": image_uri,
5902-
"ModelDataUrl": model_data,
59035902
}
5903+
if model_data is not None:
5904+
container["ModelDataUrl"] = model_data
5905+
59045906
containers = [container]
59055907

59065908
model_package_args = {
59075909
"containers": containers,
5908-
"content_types": content_types,
5909-
"response_types": response_types,
59105910
"inference_instances": inference_instances,
59115911
"transform_instances": transform_instances,
59125912
"marketplace_cert": marketplace_cert,
59135913
}
59145914

5915+
if content_types is not None:
5916+
model_package_args["content_types"] = content_types
5917+
if response_types is not None:
5918+
model_package_args["response_types"] = response_types
59155919
if model_package_name is not None:
59165920
model_package_args["model_package_name"] = model_package_name
59175921
if model_package_group_name is not None:

src/sagemaker/sklearn/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def __init__(
145145

146146
def register(
147147
self,
148-
content_types: List[Union[str, PipelineVariable]],
149-
response_types: List[Union[str, PipelineVariable]],
148+
content_types: List[Union[str, PipelineVariable]] = None,
149+
response_types: List[Union[str, PipelineVariable]] = None,
150150
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
151151
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
152152
model_package_name: Optional[Union[str, PipelineVariable]] = None,

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207

208208
def register(
209209
self,
210-
content_types: List[Union[str, PipelineVariable]],
211-
response_types: List[Union[str, PipelineVariable]],
210+
content_types: List[Union[str, PipelineVariable]] = None,
211+
response_types: List[Union[str, PipelineVariable]] = None,
212212
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
213213
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
214214
model_package_name: Optional[Union[str, PipelineVariable]] = None,

0 commit comments

Comments
 (0)