Skip to content

Commit 76d46d0

Browse files
authored
fix: enable kms support for repack_model (#1061)
* fix: enable kms support for repack_model Currently repack_model doesn't accept a kms key. This change added a kms_key argument to the fucntion. In addition repack_model will always use the output_kms_key inside the Estimator if it's set.
1 parent d368524 commit 76d46d0

File tree

17 files changed

+69
-13
lines changed

17 files changed

+69
-13
lines changed

src/sagemaker/amazon/kmeans.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def __init__(
148148
self.center_factor = center_factor
149149
self.eval_metrics = eval_metrics
150150

151-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
151+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
152152
"""Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing
153153
the latest s3 model data produced by this Estimator.
154154
@@ -158,12 +158,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
158158
Default: use subnets and security groups from this Estimator.
159159
* 'Subnets' (list[str]): List of subnet ids.
160160
* 'SecurityGroupIds' (list[str]): List of security group ids.
161+
**kwargs: Additional kwargs passed to the KMeansModel constructor.
161162
"""
162163
return KMeansModel(
163164
self.model_data,
164165
self.role,
165166
self.sagemaker_session,
166167
vpc_config=self.get_vpc_config(vpc_config_override),
168+
**kwargs
167169
)
168170

169171
def _prepare_for_training(self, records, mini_batch_size=5000, job_name=None):

src/sagemaker/amazon/lda.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(
122122
self.max_iterations = max_iterations
123123
self.tol = tol
124124

125-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
125+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
126126
"""Return a :class:`~sagemaker.amazon.LDAModel` referencing the latest
127127
s3 model data produced by this Estimator.
128128
@@ -132,12 +132,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
132132
Default: use subnets and security groups from this Estimator.
133133
* 'Subnets' (list[str]): List of subnet ids.
134134
* 'SecurityGroupIds' (list[str]): List of security group ids.
135+
**kwargs: Additional kwargs passed to the LDAModel constructor.
135136
"""
136137
return LDAModel(
137138
self.model_data,
138139
self.role,
139140
sagemaker_session=self.sagemaker_session,
140141
vpc_config=self.get_vpc_config(vpc_config_override),
142+
**kwargs
141143
)
142144

143145
def _prepare_for_training( # pylint: disable=signature-differs

src/sagemaker/amazon/linear_learner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def __init__(
373373
"value greater than 2."
374374
)
375375

376-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
376+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
377377
"""Return a :class:`~sagemaker.amazon.LinearLearnerModel` referencing
378378
the latest s3 model data produced by this Estimator.
379379
@@ -382,12 +382,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
382382
the model. Default: use subnets and security groups from this Estimator.
383383
* 'Subnets' (list[str]): List of subnet ids.
384384
* 'SecurityGroupIds' (list[str]): List of security group ids.
385+
**kwargs: Additional kwargs passed to the LinearLearnerModel constructor.
385386
"""
386387
return LinearLearnerModel(
387388
self.model_data,
388389
self.role,
389390
self.sagemaker_session,
390391
vpc_config=self.get_vpc_config(vpc_config_override),
392+
**kwargs
391393
)
392394

393395
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

src/sagemaker/chainer/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def create_model(
162162
entry_point=None,
163163
source_dir=None,
164164
dependencies=None,
165+
**kwargs
165166
):
166167
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an
167168
``Endpoint``.
@@ -186,6 +187,7 @@ def create_model(
186187
dependencies (list[str]): A list of paths to directories (absolute or relative) with
187188
any additional libraries that will be exported to the container.
188189
If not specified, the dependencies from training are used.
190+
**kwargs: Additional kwargs passed to the ChainerModel constructor.
189191
190192
Returns:
191193
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``

src/sagemaker/estimator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def deploy(
547547
)
548548
model = self._compiled_models[family]
549549
else:
550+
kwargs["model_kms_key"] = self.output_kms_key
550551
model = self.create_model(**kwargs)
551552
model.name = model_name
552553
return model.deploy(
@@ -734,7 +735,9 @@ def transformer(
734735
model_name = self._current_job_name
735736
else:
736737
model_name = self.latest_training_job.name
737-
model = self.create_model(vpc_config_override=vpc_config_override)
738+
model = self.create_model(
739+
vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key
740+
)
738741

739742
# not all create_model() implementations have the same kwargs
740743
model.name = model_name
@@ -1716,6 +1719,7 @@ def transformer(
17161719
model_server_workers=model_server_workers,
17171720
entry_point=entry_point,
17181721
vpc_config_override=vpc_config_override,
1722+
model_kms_key=self.output_kms_key,
17191723
)
17201724
model._create_sagemaker_model(instance_type, tags=tags)
17211725

src/sagemaker/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
vpc_config=None,
8080
sagemaker_session=None,
8181
enable_network_isolation=False,
82+
model_kms_key=None,
8283
):
8384
"""Initialize an SageMaker ``Model``.
8485
@@ -114,6 +115,8 @@ def __init__(
114115
network isolation in the endpoint, isolating the model
115116
container. No inbound or outbound network calls can be made to
116117
or from the model container.
118+
model_kms_key (str): KMS key ARN used to encrypt the repacked
119+
model archive file if the model is repacked
117120
"""
118121
self.model_data = model_data
119122
self.image = image
@@ -127,6 +130,7 @@ def __init__(
127130
self.endpoint_name = None
128131
self._is_compiled_model = False
129132
self._enable_network_isolation = enable_network_isolation
133+
self.model_kms_key = model_kms_key
130134

131135
def prepare_container_def(
132136
self, instance_type, accelerator_type=None
@@ -799,6 +803,7 @@ def _upload_code(self, key_prefix, repack=False):
799803
model_uri=self.model_data,
800804
repacked_model_uri=repacked_model_data,
801805
sagemaker_session=self.sagemaker_session,
806+
kms_key=self.model_kms_key,
802807
)
803808

804809
self.repacked_model_data = repacked_model_data

src/sagemaker/mxnet/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def create_model(
141141
source_dir=None,
142142
dependencies=None,
143143
image_name=None,
144+
**kwargs
144145
):
145146
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
146147
``Endpoint``.
@@ -171,6 +172,7 @@ def create_model(
171172
Examples:
172173
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
173174
custom-image:latest.
175+
**kwargs: Additional kwargs passed to the MXNetModel constructor.
174176
175177
Returns:
176178
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.

src/sagemaker/pytorch/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def create_model(
115115
entry_point=None,
116116
source_dir=None,
117117
dependencies=None,
118+
**kwargs
118119
):
119120
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an
120121
``Endpoint``.
@@ -139,6 +140,7 @@ def create_model(
139140
dependencies (list[str]): A list of paths to directories (absolute or relative) with
140141
any additional libraries that will be exported to the container.
141142
If not specified, the dependencies from training are used.
143+
**kwargs: Additional kwargs passed to the PyTorchModel constructor.
142144
143145
Returns:
144146
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``

src/sagemaker/rl/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def create_model(
163163
entry_point=None,
164164
source_dir=None,
165165
dependencies=None,
166+
**kwargs
166167
):
167168
"""Create a SageMaker ``RLEstimatorModel`` object that can be deployed
168169
to an Endpoint.
@@ -189,6 +190,7 @@ def create_model(
189190
folders will be copied to SageMaker in the same folder where the
190191
entry_point is copied. If the ```source_dir``` points to S3,
191192
code will be uploaded and the S3 location will be used instead.
193+
**kwargs: Additional kwargs passed to the FrameworkModel constructor.
192194
193195
Returns:
194196
sagemaker.model.FrameworkModel: Depending on input parameters returns

src/sagemaker/tensorflow/estimator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ def create_model(
504504
entry_point=None,
505505
source_dir=None,
506506
dependencies=None,
507+
**kwargs
507508
):
508509
"""Create a ``Model`` object that can be used for creating SageMaker model entities,
509510
deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
@@ -537,6 +538,8 @@ def create_model(
537538
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
538539
set to ``None``.
539540
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
541+
**kwargs: Additional kwargs passed to ``sagemaker.tensorflow.serving.Model`` constructor
542+
and ``sagemaker.tensorflow.model.TensorFlowModel`` constructor.
540543
541544
Returns:
542545
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
@@ -552,6 +555,7 @@ def create_model(
552555
entry_point=entry_point,
553556
source_dir=source_dir,
554557
dependencies=dependencies,
558+
**kwargs
555559
)
556560

557561
return self._create_default_model(
@@ -561,6 +565,7 @@ def create_model(
561565
entry_point=entry_point,
562566
source_dir=source_dir,
563567
dependencies=dependencies,
568+
**kwargs
564569
)
565570

566571
def _create_tfs_model(
@@ -570,6 +575,7 @@ def _create_tfs_model(
570575
entry_point=None,
571576
source_dir=None,
572577
dependencies=None,
578+
**kwargs
573579
):
574580
"""Placeholder docstring"""
575581
return Model(
@@ -585,6 +591,7 @@ def _create_tfs_model(
585591
source_dir=source_dir,
586592
dependencies=dependencies,
587593
enable_network_isolation=self.enable_network_isolation(),
594+
**kwargs
588595
)
589596

590597
def _create_default_model(
@@ -595,6 +602,7 @@ def _create_default_model(
595602
entry_point=None,
596603
source_dir=None,
597604
dependencies=None,
605+
**kwargs
598606
):
599607
"""Placeholder docstring"""
600608
return TensorFlowModel(
@@ -615,6 +623,7 @@ def _create_default_model(
615623
vpc_config=self.get_vpc_config(vpc_config_override),
616624
dependencies=dependencies or self.dependencies,
617625
enable_network_isolation=self.enable_network_isolation(),
626+
**kwargs
618627
)
619628

620629
def hyperparameters(self):

0 commit comments

Comments
 (0)