Skip to content

Commit 90dbb1a

Browse files
authored
change: use Estimator.create_model in Estimator.transformer (#1041)
1 parent 3a9c336 commit 90dbb1a

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

src/sagemaker/estimator.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -705,16 +705,23 @@ def transformer(
705705
"""
706706
tags = tags or self.tags
707707

708-
if self.latest_training_job is not None:
709-
model_name = self.sagemaker_session.create_model_from_job(
710-
self.latest_training_job.name, role=role, tags=tags
711-
)
712-
else:
708+
if self.latest_training_job is None:
713709
logging.warning(
714710
"No finished training job found associated with this estimator. Please make sure"
715711
"this estimator is only used for building workflow config"
716712
)
717713
model_name = self._current_job_name
714+
else:
715+
model_name = self.latest_training_job.name
716+
717+
model = self.create_model()
718+
719+
# not all create_model() implementations have the same kwargs
720+
model.name = model_name
721+
if role is not None:
722+
model.role = role
723+
724+
model._create_sagemaker_model(instance_type, tags=tags)
718725

719726
return Transformer(
720727
model_name,

tests/integ/test_transformer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,15 @@ def test_transform_byo_estimator(sagemaker_session, cpu_instance_type):
279279
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
280280
kmeans.fit(records, job_name=job_name)
281281

282+
estimator = Estimator.attach(training_job_name=job_name, sagemaker_session=sagemaker_session)
283+
estimator._enable_network_isolation = True
284+
282285
transform_input_path = os.path.join(data_path, "transform_input.csv")
283286
transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
284287
transform_input = kmeans.sagemaker_session.upload_data(
285288
path=transform_input_path, key_prefix=transform_input_key_prefix
286289
)
287290

288-
estimator = Estimator.attach(training_job_name=job_name, sagemaker_session=sagemaker_session)
289-
290291
transformer = estimator.transformer(1, cpu_instance_type, tags=tags)
291292
transformer.transform(transform_input, content_type="text/csv")
292293

@@ -297,6 +298,8 @@ def test_transform_byo_estimator(sagemaker_session, cpu_instance_type):
297298
model_desc = sagemaker_session.sagemaker_client.describe_model(
298299
ModelName=transformer.model_name
299300
)
301+
assert model_desc["EnableNetworkIsolation"]
302+
300303
model_tags = sagemaker_session.sagemaker_client.list_tags(
301304
ResourceArn=model_desc["ModelArn"]
302305
)["Tags"]

tests/unit/test_estimator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,7 @@ def test_ensure_latest_training_job_failure(sagemaker_session):
13921392
assert "Estimator is not associated with a training job" in str(e)
13931393

13941394

1395+
@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock())
13951396
def test_estimator_transformer_creation(sagemaker_session):
13961397
estimator = Estimator(
13971398
image_name=IMAGE_NAME,
@@ -1401,11 +1402,9 @@ def test_estimator_transformer_creation(sagemaker_session):
14011402
sagemaker_session=sagemaker_session,
14021403
)
14031404
estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
1404-
sagemaker_session.create_model_from_job.return_value = JOB_NAME
14051405

14061406
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
14071407

1408-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None, tags=None)
14091408
assert isinstance(transformer, Transformer)
14101409
assert transformer.sagemaker_session == sagemaker_session
14111410
assert transformer.instance_count == INSTANCE_COUNT
@@ -1414,6 +1413,7 @@ def test_estimator_transformer_creation(sagemaker_session):
14141413
assert transformer.tags is None
14151414

14161415

1416+
@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock())
14171417
def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
14181418
base_name = "foo"
14191419
estimator = Estimator(
@@ -1425,7 +1425,6 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
14251425
base_job_name=base_name,
14261426
)
14271427
estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
1428-
sagemaker_session.create_model_from_job.return_value = JOB_NAME
14291428

14301429
strategy = "MultiRecord"
14311430
assemble_with = "Line"
@@ -1450,7 +1449,6 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
14501449
role=ROLE,
14511450
)
14521451

1453-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE, tags=TAGS)
14541452
assert transformer.strategy == strategy
14551453
assert transformer.assemble_with == assemble_with
14561454
assert transformer.output_path == OUTPUT_PATH

0 commit comments

Comments
 (0)