Skip to content

Commit 9661e29

Browse files
authored
fix: expose kms_key parameter for deploying from training and hyperparameter tuning jobs (#1044)
1 parent 688c8a2 commit 9661e29

File tree

4 files changed

+87
-2
lines changed

4 files changed

+87
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ def deploy(
471471
update_endpoint=False,
472472
wait=True,
473473
model_name=None,
474+
kms_key=None,
474475
**kwargs
475476
):
476477
"""Deploy the trained model to an Amazon SageMaker endpoint and return a
@@ -510,6 +511,9 @@ def deploy(
510511
For more information about tags, see
511512
https://boto3.amazonaws.com/v1/documentation\
512513
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
514+
kms_key (str): The ARN of the KMS key that is used to encrypt the
515+
data on the storage volume attached to the instance hosting the
516+
endpoint.
513517
**kwargs: Passed to invocation of ``create_model()``.
514518
Implementations may customize ``create_model()`` to accept
515519
``**kwargs`` to customize model creation during deploy.
@@ -543,6 +547,7 @@ def deploy(
543547
update_endpoint=update_endpoint,
544548
tags=self.tags,
545549
wait=wait,
550+
kms_key=kms_key,
546551
)
547552

548553
@property

src/sagemaker/tuner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def deploy(
429429
endpoint_name=None,
430430
wait=True,
431431
model_name=None,
432+
kms_key=None,
432433
**kwargs
433434
):
434435
"""Deploy the best trained or user specified model to an Amazon
@@ -455,6 +456,9 @@ def deploy(
455456
model completes (default: True).
456457
model_name (str): Name to use for creating an Amazon SageMaker
457458
model. If not specified, the name of the training job is used.
459+
kms_key (str): The ARN of the KMS key that is used to encrypt the
460+
data on the storage volume attached to the instance hosting the
461+
endpoint.
458462
**kwargs: Other arguments needed for deployment. Please refer to the
459463
``create_model()`` method of the associated estimator to see
460464
what other arguments are needed.
@@ -475,6 +479,7 @@ def deploy(
475479
endpoint_name=endpoint_name,
476480
wait=wait,
477481
model_name=model_name,
482+
kms_key=kms_key,
478483
**kwargs
479484
)
480485

tests/unit/test_estimator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,6 +1879,37 @@ def test_generic_to_deploy_network_isolation(sagemaker_session):
18791879
assert kwargs["enable_network_isolation"]
18801880

18811881

1882+
@patch("sagemaker.estimator.Estimator.create_model")
1883+
def test_generic_to_deploy_kms(create_model, sagemaker_session):
1884+
e = Estimator(
1885+
IMAGE_NAME,
1886+
ROLE,
1887+
INSTANCE_COUNT,
1888+
INSTANCE_TYPE,
1889+
output_path=OUTPUT_PATH,
1890+
sagemaker_session=sagemaker_session,
1891+
)
1892+
e.fit()
1893+
1894+
model = MagicMock()
1895+
create_model.return_value = model
1896+
1897+
endpoint_name = "foo"
1898+
kms_key = "key"
1899+
e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name=endpoint_name, kms_key=kms_key)
1900+
1901+
model.deploy.assert_called_with(
1902+
instance_type=INSTANCE_TYPE,
1903+
initial_instance_count=INSTANCE_COUNT,
1904+
accelerator_type=None,
1905+
endpoint_name=endpoint_name,
1906+
update_endpoint=False,
1907+
tags=None,
1908+
wait=True,
1909+
kms_key=kms_key,
1910+
)
1911+
1912+
18821913
def test_generic_training_job_analytics(sagemaker_session):
18831914
sagemaker_session.sagemaker_client.describe_training_job = Mock(
18841915
name="describe_training_job",

tests/unit/test_tuner.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -16,7 +16,7 @@
1616
import os
1717

1818
import pytest
19-
from mock import Mock
19+
from mock import Mock, patch
2020

2121
from sagemaker import RealTimePredictor
2222
from sagemaker.amazon.amazon_estimator import RecordSet
@@ -656,6 +656,50 @@ def test_deploy_default(tuner):
656656
assert predictor.sagemaker_session == tuner.estimator.sagemaker_session
657657

658658

659+
@patch("sagemaker.estimator.Estimator.attach")
660+
@patch("sagemaker.tuner.HyperparameterTuner.best_training_job")
661+
def test_deploy_optional_params(best_training_job, estimator_attach, tuner):
662+
tuner.fit()
663+
664+
estimator = Mock()
665+
estimator_attach.return_value = estimator
666+
667+
training_job = "best-job-ever"
668+
best_training_job.return_value = training_job
669+
670+
accelerator = "ml.eia1.medium"
671+
endpoint_name = "foo"
672+
model_name = "bar"
673+
kms_key = "key"
674+
kwargs = {"some_arg": "some_value"}
675+
676+
tuner.deploy(
677+
TRAIN_INSTANCE_COUNT,
678+
TRAIN_INSTANCE_TYPE,
679+
accelerator_type=accelerator,
680+
endpoint_name=endpoint_name,
681+
wait=False,
682+
model_name=model_name,
683+
kms_key=kms_key,
684+
**kwargs
685+
)
686+
687+
estimator_attach.assert_called_with(
688+
training_job, sagemaker_session=tuner.estimator.sagemaker_session
689+
)
690+
691+
estimator.deploy.assert_called_with(
692+
TRAIN_INSTANCE_COUNT,
693+
TRAIN_INSTANCE_TYPE,
694+
accelerator_type=accelerator,
695+
endpoint_name=endpoint_name,
696+
wait=False,
697+
model_name=model_name,
698+
kms_key=kms_key,
699+
**kwargs
700+
)
701+
702+
659703
def test_wait(tuner):
660704
tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME)
661705
tuner.estimator.sagemaker_session.wait_for_tuning_job = Mock(name="wait_for_tuning_job")

0 commit comments

Comments
 (0)