Skip to content

Commit 43a9b28

Browse files
chuyang-dengDan Choi
authored andcommitted
fix: override register method in framework model class (#457)
1 parent 3ceaddb commit 43a9b28

File tree

10 files changed

+356
-6
lines changed

10 files changed

+356
-6
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ def register(
866866
transform_instances,
867867
model_package_name,
868868
model_package_group_name,
869+
image_uri,
869870
model_metrics,
870871
marketplace_cert,
871872
approval_status,

src/sagemaker/model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def register(
112112
transform_instances,
113113
model_package_name=None,
114114
model_package_group_name=None,
115+
image_uri=None,
115116
model_metrics=None,
116117
marketplace_cert=False,
117118
approval_status=None,
@@ -131,6 +132,8 @@ def register(
131132
model_package_group_name (str): Model Package Group name, exclusive to
132133
`model_package_name`, using `model_package_group_name` makes the Model Package
133134
versioned (default: None).
135+
image_uri (str): Inference image uri for the container. Model class' self.image will
136+
be used if it is None (default: None).
134137
model_metrics (ModelMetrics): ModelMetrics object (default: None).
135138
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
136139
for AWS Marketplace (default: False).
@@ -151,6 +154,7 @@ def register(
151154
transform_instances,
152155
model_package_name,
153156
model_package_group_name,
157+
image_uri,
154158
model_metrics,
155159
marketplace_cert,
156160
approval_status,
@@ -159,7 +163,11 @@ def register(
159163
model_package = self.sagemaker_session.create_model_package_from_containers(
160164
**model_pkg_args
161165
)
162-
return model_package.get("ModelPackageArn")
166+
return ModelPackage(
167+
role=self.role,
168+
model_data=self.model_data,
169+
model_package_arn=model_package.get("ModelPackageArn"),
170+
)
163171

164172
def _get_model_package_args(
165173
self,
@@ -169,6 +177,7 @@ def _get_model_package_args(
169177
transform_instances,
170178
model_package_name=None,
171179
model_package_group_name=None,
180+
image_uri=None,
172181
model_metrics=None,
173182
marketplace_cert=False,
174183
approval_status=None,
@@ -187,6 +196,8 @@ def _get_model_package_args(
187196
model_package_group_name (str): Model Package Group name, exclusive to
188197
`model_package_name`, using `model_package_group_name` makes the Model Package
189198
versioned (default: None).
199+
image_uri (str): Inference image uri for the container. Model class' self.image will
200+
be used if it is None (default: None).
190201
model_metrics (ModelMetrics): ModelMetrics object (default: None).
191202
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
192203
for AWS Marketplace (default: False).
@@ -196,6 +207,8 @@ def _get_model_package_args(
196207
Returns:
197208
dict: A dictionary of method argument names and values.
198209
"""
210+
if image_uri:
211+
self.image_uri = image_uri
199212
container = {
200213
"Image": self.image_uri,
201214
"ModelDataUrl": self.model_data,

src/sagemaker/mxnet/model.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,72 @@ def __init__(
128128
super(MXNetModel, self).__init__(
129129
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
130130
)
131-
132131
self.model_server_workers = model_server_workers
133132

133+
def register(
134+
self,
135+
content_types,
136+
response_types,
137+
inference_instances,
138+
transform_instances,
139+
model_package_name=None,
140+
model_package_group_name=None,
141+
image_uri=None,
142+
model_metrics=None,
143+
marketplace_cert=False,
144+
approval_status=None,
145+
description=None,
146+
):
147+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
148+
149+
Args:
150+
content_types (list): The supported MIME types for the input data.
151+
response_types (list): The supported MIME types for the output data.
152+
inference_instances (list): A list of the instance types that are used to
153+
generate inferences in real-time.
154+
transform_instances (list): A list of the instance types on which a transformation
155+
job can be run or on which an endpoint can be deployed.
156+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
157+
using `model_package_name` makes the Model Package un-versioned (default: None).
158+
model_package_group_name (str): Model Package Group name, exclusive to
159+
`model_package_name`, using `model_package_group_name` makes the Model Package
160+
versioned (default: None).
161+
image_uri (str): Inference image uri for the container. Model class' self.image will
162+
be used if it is None (default: None).
163+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
164+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
165+
for AWS Marketplace (default: False).
166+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
167+
or "PendingManualApproval" (default: "PendingManualApproval").
168+
description (str): Model Package description (default: None).
169+
170+
Returns:
171+
str: A string of SageMaker Model Package ARN.
172+
"""
173+
instance_type = inference_instances[0]
174+
self._init_sagemaker_session_if_does_not_exist(instance_type)
175+
176+
if image_uri:
177+
self.image_uri = image_uri
178+
if not self.image_uri:
179+
self.image_uri = self.serving_image_uri(
180+
region_name=self.sagemaker_session.boto_session.region_name,
181+
instance_type=instance_type,
182+
)
183+
return super(MXNetModel, self).register(
184+
content_types,
185+
response_types,
186+
inference_instances,
187+
transform_instances,
188+
model_package_name,
189+
model_package_group_name,
190+
image_uri,
191+
model_metrics,
192+
marketplace_cert,
193+
approval_status,
194+
description,
195+
)
196+
134197
def prepare_container_def(self, instance_type=None, accelerator_type=None):
135198
"""Return a container definition with framework configuration set in
136199
model environment variables.

src/sagemaker/pytorch/model.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,70 @@ def __init__(
130130

131131
self.model_server_workers = model_server_workers
132132

133+
def register(
134+
self,
135+
content_types,
136+
response_types,
137+
inference_instances,
138+
transform_instances,
139+
model_package_name=None,
140+
model_package_group_name=None,
141+
image_uri=None,
142+
model_metrics=None,
143+
marketplace_cert=False,
144+
approval_status=None,
145+
description=None,
146+
):
147+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
148+
149+
Args:
150+
content_types (list): The supported MIME types for the input data.
151+
response_types (list): The supported MIME types for the output data.
152+
inference_instances (list): A list of the instance types that are used to
153+
generate inferences in real-time.
154+
transform_instances (list): A list of the instance types on which a transformation
155+
job can be run or on which an endpoint can be deployed.
156+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
157+
using `model_package_name` makes the Model Package un-versioned (default: None).
158+
model_package_group_name (str): Model Package Group name, exclusive to
159+
`model_package_name`, using `model_package_group_name` makes the Model Package
160+
versioned (default: None).
161+
image_uri (str): Inference image uri for the container. Model class' self.image will
162+
be used if it is None (default: None).
163+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
164+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
165+
for AWS Marketplace (default: False).
166+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
167+
or "PendingManualApproval" (default: "PendingManualApproval").
168+
description (str): Model Package description (default: None).
169+
170+
Returns:
171+
str: A string of SageMaker Model Package ARN.
172+
"""
173+
instance_type = inference_instances[0]
174+
self._init_sagemaker_session_if_does_not_exist(instance_type)
175+
176+
if image_uri:
177+
self.image_uri = image_uri
178+
if not self.image_uri:
179+
self.image_uri = self.serving_image_uri(
180+
region_name=self.sagemaker_session.boto_session.region_name,
181+
instance_type=instance_type,
182+
)
183+
return super(PyTorchModel, self).register(
184+
content_types,
185+
response_types,
186+
inference_instances,
187+
transform_instances,
188+
model_package_name,
189+
model_package_group_name,
190+
image_uri,
191+
model_metrics,
192+
marketplace_cert,
193+
approval_status,
194+
description,
195+
)
196+
133197
def prepare_container_def(self, instance_type=None, accelerator_type=None):
134198
"""Return a container definition with framework configuration set in
135199
model environment variables.

src/sagemaker/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2634,7 +2634,7 @@ def _get_create_model_package_request(
26342634
"SupportedTransformInstanceTypes": transform_instances,
26352635
}
26362636
request_dict["InferenceSpecification"] = inference_specification
2637-
request_dict["CertifyForMarketPlace"] = marketplace_cert
2637+
request_dict["CertifyForMarketplace"] = marketplace_cert
26382638
request_dict["ModelApprovalStatus"] = approval_status
26392639
return request_dict
26402640

src/sagemaker/tensorflow/model.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,70 @@ def __init__(
202202
)
203203
self._container_log_level = container_log_level
204204

205+
def register(
206+
self,
207+
content_types,
208+
response_types,
209+
inference_instances,
210+
transform_instances,
211+
model_package_name=None,
212+
model_package_group_name=None,
213+
image_uri=None,
214+
model_metrics=None,
215+
marketplace_cert=False,
216+
approval_status=None,
217+
description=None,
218+
):
219+
"""Creates a model package for creating SageMaker models or listing on Marketplace.
220+
221+
Args:
222+
content_types (list): The supported MIME types for the input data.
223+
response_types (list): The supported MIME types for the output data.
224+
inference_instances (list): A list of the instance types that are used to
225+
generate inferences in real-time.
226+
transform_instances (list): A list of the instance types on which a transformation
227+
job can be run or on which an endpoint can be deployed.
228+
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
229+
using `model_package_name` makes the Model Package un-versioned (default: None).
230+
model_package_group_name (str): Model Package Group name, exclusive to
231+
`model_package_name`, using `model_package_group_name` makes the Model Package
232+
versioned (default: None).
233+
image_uri (str): Inference image uri for the container. Model class' self.image will
234+
be used if it is None (default: None).
235+
model_metrics (ModelMetrics): ModelMetrics object (default: None).
236+
marketplace_cert (bool): A boolean value indicating if the Model Package is certified
237+
for AWS Marketplace (default: False).
238+
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
239+
or "PendingManualApproval" (default: "PendingManualApproval").
240+
description (str): Model Package description (default: None).
241+
242+
Returns:
243+
str: A string of SageMaker Model Package ARN.
244+
"""
245+
instance_type = inference_instances[0]
246+
self._init_sagemaker_session_if_does_not_exist(instance_type)
247+
248+
if image_uri:
249+
self.image_uri = image_uri
250+
if not self.image_uri:
251+
self.image_uri = self.serving_image_uri(
252+
region_name=self.sagemaker_session.boto_session.region_name,
253+
instance_type=instance_type,
254+
)
255+
return super(TensorFlowModel, self).register(
256+
content_types,
257+
response_types,
258+
inference_instances,
259+
transform_instances,
260+
model_package_name,
261+
model_package_group_name,
262+
image_uri,
263+
model_metrics,
264+
marketplace_cert,
265+
approval_status,
266+
description,
267+
)
268+
205269
def deploy(
206270
self,
207271
initial_instance_count,

src/sagemaker/workflow/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,8 @@ def arguments(self) -> RequestType:
305305
**model_package_args
306306
)
307307
# these are not available in the workflow service
308-
if "CertifyForMarketPlace" in request_dict:
309-
request_dict.pop("CertifyForMarketPlace")
308+
if "CertifyForMarketplace" in request_dict:
309+
request_dict.pop("CertifyForMarketplace")
310310
if "Description" in request_dict:
311311
request_dict.pop("Description")
312312
if "ModelApprovalStatus" in request_dict:

tests/integ/test_mxnet.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020

2121
import tests.integ
22+
from sagemaker import ModelPackage
2223
from sagemaker.mxnet.estimator import MXNet
2324
from sagemaker.mxnet.model import MXNetModel
2425
from sagemaker.utils import sagemaker_timestamp
@@ -160,6 +161,45 @@ def test_deploy_model(
160161
assert "Could not find model" in str(exception.value)
161162

162163

164+
def test_register_model_package(
165+
mxnet_training_job,
166+
sagemaker_session,
167+
mxnet_inference_latest_version,
168+
mxnet_inference_latest_py_version,
169+
cpu_instance_type,
170+
):
171+
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
172+
173+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
174+
desc = sagemaker_session.sagemaker_client.describe_training_job(
175+
TrainingJobName=mxnet_training_job
176+
)
177+
model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
178+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
179+
model = MXNetModel(
180+
model_data,
181+
"SageMakerRole",
182+
entry_point=script_path,
183+
py_version=mxnet_inference_latest_py_version,
184+
sagemaker_session=sagemaker_session,
185+
framework_version=mxnet_inference_latest_version,
186+
)
187+
model_package_name = "register-model-package-{}".format(sagemaker_timestamp())
188+
model_pkg = model.register(
189+
content_types=["application/json"],
190+
response_types=["application/json"],
191+
inference_instances=["ml.m5.large"],
192+
transform_instances=["ml.m5.large"],
193+
model_package_name=model_package_name,
194+
)
195+
assert isinstance(model_pkg, ModelPackage)
196+
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
197+
data = numpy.zeros(shape=(1, 1, 28, 28))
198+
result = predictor.predict(data)
199+
assert result is not None
200+
sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_name)
201+
202+
163203
def test_deploy_model_with_tags_and_kms(
164204
mxnet_training_job,
165205
sagemaker_session,

0 commit comments

Comments
 (0)