@@ -191,6 +191,8 @@ def __init__(
191191 volume_size : int = 30 ,
192192 encrypt_inter_container_traffic : bool = None ,
193193 spark_config : SparkConfig = None ,
194+ use_spot_instances = False ,
195+ max_wait_time_in_seconds = None ,
194196 ):
195197 """Initialize a _JobSettings instance which configures the remote job.
196198
@@ -353,6 +355,14 @@ def __init__(
353355 Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
354356 will be used for training. Note that ``image_uri`` can not be specified at the
355357 same time otherwise a ``ValueError`` is thrown. Defaults to ``None``.
358+
359+ use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
360+ training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set.
361+ Defaults to ``False``.
362+
363+ max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
364+ After this amount of time Amazon SageMaker will stop waiting for the managed spot
365+ training job to complete. Defaults to ``None``.
356366 """
357367 self .sagemaker_session = sagemaker_session or Session ()
358368 self .environment_variables = resolve_value_from_config (
@@ -439,6 +449,8 @@ def __init__(
439449 self .max_retry_attempts = max_retry_attempts
440450 self .keep_alive_period_in_seconds = keep_alive_period_in_seconds
441451 self .spark_config = spark_config
452+ self .use_spot_instances = use_spot_instances
453+ self .max_wait_time_in_seconds = max_wait_time_in_seconds
442454 self .job_conda_env = resolve_value_from_config (
443455 direct_input = job_conda_env ,
444456 config_path = REMOTE_FUNCTION_JOB_CONDA_ENV ,
@@ -648,12 +660,16 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
648660
649661 stored_function .save (func , * func_args , ** func_kwargs )
650662
663+ stopping_condition = {
664+ "MaxRuntimeInSeconds" : job_settings .max_runtime_in_seconds ,
665+ }
666+ if job_settings .max_wait_time_in_seconds is not None :
667+ stopping_condition ["MaxWaitTimeInSeconds" ] = job_settings .max_wait_time_in_seconds
668+
651669 request_dict = dict (
652670 TrainingJobName = job_name ,
653671 RoleArn = job_settings .role ,
654- StoppingCondition = {
655- "MaxRuntimeInSeconds" : job_settings .max_runtime_in_seconds ,
656- },
672+ StoppingCondition = stopping_condition ,
657673 RetryStrategy = {"MaximumRetryAttempts" : job_settings .max_retry_attempts },
658674 )
659675
@@ -742,6 +758,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
742758 if job_settings .vpc_config :
743759 request_dict ["VpcConfig" ] = job_settings .vpc_config
744760
761+ request_dict ["EnableManagedSpotTraining" ] = job_settings .use_spot_instances
762+
745763 request_dict ["Environment" ] = job_settings .environment_variables
746764
747765 extended_request = _extend_spark_config_to_request (request_dict , job_settings , s3_base_uri )
0 commit comments