@@ -124,6 +124,7 @@ def __init__(
124124 profiler_config = None ,
125125 disable_profiler = False ,
126126 environment = None ,
127+ max_retry_attempts = None ,
127128 ** kwargs ,
128129 ):
129130 """Initialize an ``EstimatorBase`` instance.
@@ -269,6 +270,13 @@ def __init__(
269270 will be disabled (default: ``False``).
270271 environment (dict[str, str]) : Environment variables to be set for
271272 use during training job (default: ``None``)
273+ max_retry_attempts (int): The number of times to move a job to the STARTING status.
274+ You can specify between 1 and 30 attempts.
275+ If the value is greater than zero,
276+ the job is retried on InternalServerFailure
277+ up to that many times.
278+ You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
279+ (default: ``None``)
272280
273281 """
274282 instance_count = renamed_kwargs (
@@ -357,6 +365,8 @@ def __init__(
357365
358366 self .environment = environment
359367
368+ self .max_retry_attempts = max_retry_attempts
369+
360370 if not _region_supports_profiler (self .sagemaker_session .boto_region_name ):
361371 self .disable_profiler = True
362372
@@ -1114,6 +1124,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
11141124 if max_wait :
11151125 init_params ["max_wait" ] = max_wait
11161126
1127+ if job_details .get ("RetryStrategy" , False ):
1128+ init_params ["max_retry_attempts" ] = job_details .get ("RetryStrategy" , {}).get (
1129+ "MaximumRetryAttempts"
1130+ )
1131+ max_wait = job_details .get ("StoppingCondition" , {}).get ("MaxWaitTimeInSeconds" )
1132+ if max_wait :
1133+ init_params ["max_wait" ] = max_wait
11171134 return init_params
11181135
11191136 def transformer (
@@ -1489,6 +1506,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
14891506 if estimator .enable_network_isolation ():
14901507 train_args ["enable_network_isolation" ] = True
14911508
1509+ if estimator .max_retry_attempts is not None :
1510+ train_args ["retry_strategy" ] = {"MaximumRetryAttempts" : estimator .max_retry_attempts }
1511+ else :
1512+ train_args ["retry_strategy" ] = None
1513+
14921514 if estimator .encrypt_inter_container_traffic :
14931515 train_args ["encrypt_inter_container_traffic" ] = True
14941516
@@ -1666,6 +1688,7 @@ def __init__(
16661688 profiler_config = None ,
16671689 disable_profiler = False ,
16681690 environment = None ,
1691+ max_retry_attempts = None ,
16691692 ** kwargs ,
16701693 ):
16711694 """Initialize an ``Estimator`` instance.
@@ -1816,6 +1839,13 @@ def __init__(
18161839 will be disabled (default: ``False``).
18171840 environment (dict[str, str]) : Environment variables to be set for
18181841 use during training job (default: ``None``)
1842+ max_retry_attempts (int): The number of times to move a job to the STARTING status.
1843+ You can specify between 1 and 30 attempts.
1844+ If the value is greater than zero,
1845+ the job is retried on InternalServerFailure
1846+ up to that many times.
1847+ You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
1848+ (default: ``None``)
18191849 """
18201850 self .image_uri = image_uri
18211851 self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
@@ -1850,6 +1880,7 @@ def __init__(
18501880 profiler_config = profiler_config ,
18511881 disable_profiler = disable_profiler ,
18521882 environment = environment ,
1883+ max_retry_attempts = max_retry_attempts ,
18531884 ** kwargs ,
18541885 )
18551886
0 commit comments