@@ -53,6 +53,8 @@ def __init__(
53
53
model_channel_name = "model" ,
54
54
metric_definitions = None ,
55
55
encrypt_inter_container_traffic = False ,
56
+ train_use_spot_instances = False ,
57
+ train_max_wait = None ,
56
58
** kwargs # pylint: disable=W0613
57
59
):
58
60
"""Initialize an ``AlgorithmEstimator`` instance.
@@ -125,6 +127,17 @@ def __init__(
125
127
expression used to extract the metric from the logs.
126
128
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
127
129
containers is encrypted for the training job (default: ``False``).
130
+ train_use_spot_instances (bool): Specifies whether to use SageMaker
131
+ Managed Spot instances for training. If enabled then the
132
+ `train_max_wait` arg should also be set.
133
+
134
+ More information:
135
+ https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
136
+ (default: ``False``).
137
+ train_max_wait (int): Timeout in seconds waiting for spot training
138
+ instances (default: None). After this amount of time Amazon
139
+ SageMaker will stop waiting for Spot instances to become
140
+ available (default: ``None``).
128
141
**kwargs: Additional kwargs. This is unused. It's only added for AlgorithmEstimator
129
142
to ignore the irrelevant arguments.
130
143
"""
@@ -148,6 +161,8 @@ def __init__(
148
161
model_channel_name = model_channel_name ,
149
162
metric_definitions = metric_definitions ,
150
163
encrypt_inter_container_traffic = encrypt_inter_container_traffic ,
164
+ train_use_spot_instances = train_use_spot_instances ,
165
+ train_max_wait = train_max_wait ,
151
166
)
152
167
153
168
self .algorithm_spec = self .sagemaker_session .sagemaker_client .describe_algorithm (
0 commit comments