Skip to content

Commit 3624364

Browse files
authored
feature: add spot instance support for AlgorithmEstimator (#1672)
1 parent 31ac368 commit 3624364

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

src/sagemaker/algorithm.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def __init__(
5353
model_channel_name="model",
5454
metric_definitions=None,
5555
encrypt_inter_container_traffic=False,
56+
train_use_spot_instances=False,
57+
train_max_wait=None,
5658
**kwargs # pylint: disable=W0613
5759
):
5860
"""Initialize an ``AlgorithmEstimator`` instance.
@@ -125,6 +127,17 @@ def __init__(
125127
expression used to extract the metric from the logs.
126128
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
127129
containers is encrypted for the training job (default: ``False``).
130+
train_use_spot_instances (bool): Specifies whether to use SageMaker
131+
Managed Spot instances for training. If enabled then the
132+
`train_max_wait` arg should also be set.
133+
134+
More information:
135+
https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
136+
(default: ``False``).
137+
train_max_wait (int): Timeout in seconds waiting for spot training
138+
instances (default: None). After this amount of time Amazon
139+
SageMaker will stop waiting for Spot instances to become
140+
available (default: ``None``).
128141
**kwargs: Additional kwargs. This is unused. It's only added for AlgorithmEstimator
129142
to ignore the irrelevant arguments.
130143
"""
@@ -148,6 +161,8 @@ def __init__(
148161
model_channel_name=model_channel_name,
149162
metric_definitions=metric_definitions,
150163
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
164+
train_use_spot_instances=train_use_spot_instances,
165+
train_max_wait=train_max_wait,
151166
)
152167

153168
self.algorithm_spec = self.sagemaker_session.sagemaker_client.describe_algorithm(

tests/unit/test_algorithm.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,18 @@ def test_algorithm_attach_from_hyperparameter_tuning():
10151015
assert estimator.train_volume_size == train_volume_size
10161016
assert estimator.input_mode == input_mode
10171017
assert estimator.sagemaker_session == session
1018+
1019+
1020+
@patch("sagemaker.Session")
1021+
def test_algorithm_supported_with_spot_instances(session):
1022+
session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE)
1023+
1024+
assert AlgorithmEstimator(
1025+
algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
1026+
role="SageMakerRole",
1027+
train_instance_type="ml.m4.xlarge",
1028+
train_instance_count=1,
1029+
train_use_spot_instances=True,
1030+
train_max_wait=500,
1031+
sagemaker_session=session,
1032+
)

0 commit comments

Comments
 (0)