@@ -53,6 +53,8 @@ def __init__(
5353 model_channel_name = "model" ,
5454 metric_definitions = None ,
5555 encrypt_inter_container_traffic = False ,
56+ train_use_spot_instances = False ,
57+ train_max_wait = None ,
5658 ** kwargs # pylint: disable=W0613
5759 ):
5860 """Initialize an ``AlgorithmEstimator`` instance.
@@ -125,6 +127,17 @@ def __init__(
125127 expression used to extract the metric from the logs.
126128 encrypt_inter_container_traffic (bool): Specifies whether traffic between training
127129 containers is encrypted for the training job (default: ``False``).
130+ train_use_spot_instances (bool): Specifies whether to use SageMaker
131+ Managed Spot instances for training. If enabled then the
132+ `train_max_wait` arg should also be set.
133+
134+ More information:
135+ https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
136+ (default: ``False``).
137+ train_max_wait (int): Timeout in seconds waiting for spot training
138+ instances (default: None). After this amount of time Amazon
139+ SageMaker will stop waiting for Spot instances to become
140+ available (default: ``None``).
128141 **kwargs: Additional kwargs. This is unused. It's only added for AlgorithmEstimator
129142 to ignore the irrelevant arguments.
130143 """
@@ -148,6 +161,8 @@ def __init__(
148161 model_channel_name = model_channel_name ,
149162 metric_definitions = metric_definitions ,
150163 encrypt_inter_container_traffic = encrypt_inter_container_traffic ,
164+ train_use_spot_instances = train_use_spot_instances ,
165+ train_max_wait = train_max_wait ,
151166 )
152167
153168 self .algorithm_spec = self .sagemaker_session .sagemaker_client .describe_algorithm (
0 commit comments