Skip to content

Commit d61dcb1

Browse files
committed
Fix for Estimator Training details on register
1 parent 3d4d258 commit d61dcb1

File tree

3 files changed

+75
-2
lines changed

3 files changed

+75
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from sagemaker.interactive_apps import SupportedInteractiveAppTypes
6969
from sagemaker.interactive_apps.tensorboard import TensorBoardApp
7070
from sagemaker.instance_group import InstanceGroup
71+
from sagemaker.model_card.model_card import ModelCard, TrainingDetails
7172
from sagemaker.utils import instance_supports_kms
7273
from sagemaker.job import _Job
7374
from sagemaker.jumpstart.utils import (
@@ -1797,8 +1798,16 @@ def register(
17971798
else:
17981799
if "model_kms_key" not in kwargs:
17991800
kwargs["model_kms_key"] = self.output_kms_key
1800-
model = self.create_model(image_uri=image_uri, **kwargs)
1801+
model = self.create_model(image_uri=image_uri, name=model_name, **kwargs)
18011802
model.name = model_name
1803+
if self.model_data is not None and model_card is None:
1804+
training_details = TrainingDetails.from_model_s3_artifacts(
1805+
model_artifacts=[self.model_data], sagemaker_session=self.sagemaker_session
1806+
)
1807+
model_card = ModelCard(
1808+
name="estimator_card",
1809+
training_details=training_details,
1810+
)
18021811
return model.register(
18031812
content_types,
18041813
response_types,

src/sagemaker/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,10 @@ def register(
549549
model_package_group_name = utils.base_name_from_image(
550550
self.image_uri, default_base_name=ModelPackage.__name__
551551
)
552-
if model_package_group_name is not None:
552+
if (
553+
model_package_group_name is not None
554+
and self.model_type is not JumpStartModelType.PROPRIETARY
555+
):
553556
container_def = self.prepare_container_def(accept_eula=accept_eula)
554557
container_def = update_container_with_inference_params(
555558
framework=framework,

tests/integ/test_byo_estimator.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,20 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import io
1516
import json
1617
import os
1718

19+
import numpy as np
20+
1821
import pytest
22+
import sagemaker.amazon.common as smac
23+
1924

2025
import sagemaker
2126
from sagemaker import image_uris
2227
from sagemaker.estimator import Estimator
28+
from sagemaker.s3 import S3Uploader
2329
from sagemaker.serializers import SimpleBaseSerializer
2430
from sagemaker.utils import unique_name_from_base
2531
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, datasets
@@ -102,6 +108,61 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
102108
assert prediction["score"] is not None
103109

104110

111+
def test_estimator_register_publish_training_details(
112+
sagemaker_session, region, cpu_instance_type, training_set
113+
):
114+
115+
bucket = sagemaker_session.default_bucket()
116+
prefix = "model-card-sample-notebook"
117+
118+
raw_data = (
119+
(0.5, 0),
120+
(0.75, 0),
121+
(1.0, 0),
122+
(1.25, 0),
123+
(1.50, 0),
124+
(1.75, 0),
125+
(2.0, 0),
126+
(2.25, 1),
127+
(2.5, 0),
128+
(2.75, 1),
129+
(3.0, 0),
130+
(3.25, 1),
131+
(3.5, 0),
132+
(4.0, 1),
133+
(4.25, 1),
134+
(4.5, 1),
135+
(4.75, 1),
136+
(5.0, 1),
137+
(5.5, 1),
138+
)
139+
training_data = np.array(raw_data).astype("float32")
140+
labels = training_data[:, 1]
141+
142+
# upload data to S3 bucket
143+
buf = io.BytesIO()
144+
smac.write_numpy_to_dense_tensor(buf, training_data, labels)
145+
buf.seek(0)
146+
s3_train_data = f"s3://{bucket}/{prefix}/train"
147+
S3Uploader.upload_bytes(b=buf, s3_uri=s3_train_data, sagemaker_session=sagemaker_session)
148+
output_location = f"s3://{bucket}/{prefix}/output"
149+
container = image_uris.retrieve("linear-learner", region)
150+
estimator = sagemaker.estimator.Estimator(
151+
container,
152+
role="SageMakerRole",
153+
instance_count=1,
154+
instance_type="ml.m4.xlarge",
155+
output_path=output_location,
156+
sagemaker_session=sagemaker_session,
157+
)
158+
estimator.set_hyperparameters(
159+
feature_dim=2, mini_batch_size=10, predictor_type="binary_classifier"
160+
)
161+
estimator.fit({"train": s3_train_data})
162+
print(f"Training job name: {estimator.latest_training_job.name}")
163+
estimator.register()
164+
165+
105166
def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, training_set):
106167
image_uri = image_uris.retrieve("factorization-machines", region)
107168
endpoint_name = unique_name_from_base("byo")

0 commit comments

Comments
 (0)