@@ -82,6 +82,10 @@ def __init__(
8282 model_channel_name = "model" ,
8383 metric_definitions = None ,
8484 encrypt_inter_container_traffic = False ,
85+ train_use_spot_instances = False ,
86+ train_max_wait = None ,
87+ checkpoint_s3_uri = None ,
88+ checkpoint_local_path = None ,
8589 ):
8690 """Initialize an ``EstimatorBase`` instance.
8791
@@ -157,6 +161,28 @@ def __init__(
157161 encrypt_inter_container_traffic (bool): Specifies whether traffic
158162 between training containers is encrypted for the training job
159163 (default: ``False``).
164+ train_use_spot_instances (bool): Specifies whether to use SageMaker
165+ Managed Spot instances for training. If enabled then the
166+ `train_max_wait` arg should also be set.
167+
168+ More information:
169+ https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
170+ (default: ``False``).
171+ train_max_wait (int): Timeout in seconds waiting for spot training
172+ instances (default: None). After this amount of time Amazon
173+ SageMaker will stop waiting for Spot instances to become
174+ available (default: ``None``).
175+ checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
176+ that the algorithm persists (if any) during training. (default:
177+ ``None``).
178+ checkpoint_local_path (str): The local path that the algorithm
179+ writes its checkpoints to. SageMaker will persist all files
180+ under this path to `checkpoint_s3_uri` continually during
181+ training. On job startup the reverse happens - data from the
182+ s3 location is downloaded to this path before the algorithm is
183+ started. If the path is unset then SageMaker assumes the
184+ checkpoints will be provided under `/opt/ml/checkpoints/`.
185+ (default: ``None``).
160186 """
161187 self .role = role
162188 self .train_instance_count = train_instance_count
@@ -199,6 +225,10 @@ def __init__(
199225 self .security_group_ids = security_group_ids
200226
201227 self .encrypt_inter_container_traffic = encrypt_inter_container_traffic
228+ self .train_use_spot_instances = train_use_spot_instances
229+ self .train_max_wait = train_max_wait
230+ self .checkpoint_s3_uri = checkpoint_s3_uri
231+ self .checkpoint_local_path = checkpoint_local_path
202232
203233 @abstractmethod
204234 def train_image (self ):
@@ -795,10 +825,35 @@ def start_new(cls, estimator, inputs):
795825 else :
796826 train_args ["image" ] = estimator .train_image ()
797827
828+ cls ._add_spot_checkpoint_args (local_mode , estimator , train_args )
829+
798830 estimator .sagemaker_session .train (** train_args )
799831
800832 return cls (estimator .sagemaker_session , estimator ._current_job_name )
801833
834+ @classmethod
835+ def _add_spot_checkpoint_args (cls , local_mode , estimator , train_args ):
836+ """
837+ Args:
838+ local_mode:
839+ estimator:
840+ train_args:
841+ """
842+ if estimator .train_use_spot_instances :
843+ if local_mode :
844+ raise ValueError ("Spot training is not supported in local mode." )
845+ train_args ["train_use_spot_instances" ] = True
846+
847+ if estimator .checkpoint_s3_uri :
848+ if local_mode :
849+ raise ValueError ("Setting checkpoint_s3_uri is not supported in local mode." )
850+ train_args ["checkpoint_s3_uri" ] = estimator .checkpoint_s3_uri
851+
852+ if estimator .checkpoint_local_path :
853+ if local_mode :
854+ raise ValueError ("Setting checkpoint_local_path is not supported in local mode." )
855+ train_args ["checkpoint_local_path" ] = estimator .checkpoint_local_path
856+
802857 @classmethod
803858 def _is_local_channel (cls , input_uri ):
804859 """
@@ -845,6 +900,10 @@ def __init__(
845900 model_channel_name = "model" ,
846901 metric_definitions = None ,
847902 encrypt_inter_container_traffic = False ,
903+ train_use_spot_instances = False ,
904+ train_max_wait = None ,
905+ checkpoint_s3_uri = None ,
906+ checkpoint_local_path = None ,
848907 ):
849908 """Initialize an ``Estimator`` instance.
850909
@@ -926,6 +985,28 @@ def __init__(
926985 encrypt_inter_container_traffic (bool): Specifies whether traffic
927986 between training containers is encrypted for the training job
928987 (default: ``False``).
988+ train_use_spot_instances (bool): Specifies whether to use SageMaker
989+ Managed Spot instances for training. If enabled then the
990+ `train_max_wait` arg should also be set.
991+
992+ More information:
993+ https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
994+ (default: ``False``).
995+ train_max_wait (int): Timeout in seconds waiting for spot training
996+ instances (default: None). After this amount of time Amazon
997+ SageMaker will stop waiting for Spot instances to become
998+ available (default: ``None``).
999+ checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
1000+ that the algorithm persists (if any) during training. (default:
1001+ ``None``).
1002+ checkpoint_local_path (str): The local path that the algorithm
1003+ writes its checkpoints to. SageMaker will persist all files
1004+ under this path to `checkpoint_s3_uri` continually during
1005+ training. On job startup the reverse happens - data from the
1006+ s3 location is downloaded to this path before the algorithm is
1007+ started. If the path is unset then SageMaker assumes the
1008+ checkpoints will be provided under `/opt/ml/checkpoints/`.
1009+ (default: ``None``).
9291010 """
9301011 self .image_name = image_name
9311012 self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
@@ -948,6 +1029,10 @@ def __init__(
9481029 model_channel_name = model_channel_name ,
9491030 metric_definitions = metric_definitions ,
9501031 encrypt_inter_container_traffic = encrypt_inter_container_traffic ,
1032+ train_use_spot_instances = train_use_spot_instances ,
1033+ train_max_wait = train_max_wait ,
1034+ checkpoint_s3_uri = checkpoint_s3_uri ,
1035+ checkpoint_local_path = checkpoint_local_path ,
9511036 )
9521037
9531038 def train_image (self ):
0 commit comments