@@ -116,6 +116,7 @@ def __init__(
116116 role : str ,
117117 instance_count : Optional [Union [int , PipelineVariable ]] = None ,
118118 instance_type : Optional [Union [str , PipelineVariable ]] = None ,
119+ keep_alive_period_in_seconds : Optional [Union [int , PipelineVariable ]] = None ,
119120 volume_size : Union [int , PipelineVariable ] = 30 ,
120121 volume_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
121122 max_run : Union [int , PipelineVariable ] = 24 * 60 * 60 ,
@@ -167,6 +168,9 @@ def __init__(
167168 instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
168169 for example, ``'ml.c4.xlarge'``. Required if instance_groups is
169170 not set.
171+ keep_alive_period_in_seconds (int or PipelineVariable): The duration of time
172+ in seconds to retain configured resources in a warm pool for subsequent
173+ training jobs (default: None).
170174 volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
171175 storing input and output data during training (default: 30).
172176
@@ -510,6 +514,7 @@ def __init__(
510514 self .role = role
511515 self .instance_count = instance_count
512516 self .instance_type = instance_type
517+ self .keep_alive_period_in_seconds = keep_alive_period_in_seconds
513518 self .instance_groups = instance_groups
514519 self .volume_size = volume_size
515520 self .volume_kms_key = volume_kms_key
@@ -1578,6 +1583,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
15781583 if "EnableNetworkIsolation" in job_details :
15791584 init_params ["enable_network_isolation" ] = job_details ["EnableNetworkIsolation" ]
15801585
1586+ if "KeepAlivePeriodInSeconds" in job_details ["ResourceConfig" ]:
1587+ init_params ["keep_alive_period_in_seconds" ] = job_details ["ResourceConfig" ][
1588+ "KeepAlivePeriodInSeconds"  # must match the key tested above; the lowercase spelling raised KeyError
1589+ ]
1590+
15811591 has_hps = "HyperParameters" in job_details
15821592 init_params ["hyperparameters" ] = job_details ["HyperParameters" ] if has_hps else {}
15831593
@@ -2126,7 +2136,9 @@ def _is_local_channel(cls, input_uri):
21262136 return isinstance (input_uri , string_types ) and input_uri .startswith ("file://" )
21272137
21282138 @classmethod
2129- def update (cls , estimator , profiler_rule_configs = None , profiler_config = None ):
2139+ def update (
2140+ cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2141+ ):
21302142 """Update a running Amazon SageMaker training job.
21312143
21322144 Args:
@@ -2135,18 +2147,23 @@ def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
21352147 updated in the training job. (default: None).
21362148 profiler_config (dict): Configuration for how profiling information is emitted with
21372149 SageMaker Debugger. (default: None).
2150+ resource_config (dict): Configuration of the resources for the training job. You can
2151+ update the keep-alive period if the warm pool status is `Available`. No other fields
2152+ can be updated. (default: None).
21382153
21392154 Returns:
21402155 sagemaker.estimator._TrainingJob: Constructed object that captures
21412156 all information about the updated training job.
21422157 """
2143- update_args = cls ._get_update_args (estimator , profiler_rule_configs , profiler_config )
2158+ update_args = cls ._get_update_args (
2159+ estimator , profiler_rule_configs , profiler_config , resource_config
2160+ )
21442161 estimator .sagemaker_session .update_training_job (** update_args )
21452162
21462163 return estimator .latest_training_job
21472164
21482165 @classmethod
2149- def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config ):
2166+ def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config , resource_config ):
21502167 """Constructs a dict of arguments for updating an Amazon SageMaker training job.
21512168
21522169 Args:
@@ -2156,13 +2173,17 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
21562173 updated in the training job. (default: None).
21572174 profiler_config (dict): Configuration for how profiling information is emitted with
21582175 SageMaker Debugger. (default: None).
2176+ resource_config (dict): Configuration of the resources for the training job. You can
2177+ update the keep-alive period if the warm pool status is `Available`. No other fields
2178+ can be updated. (default: None).
21592179
21602180 Returns:
21612181 Dict: dict for `sagemaker.session.Session.update_training_job` method
21622182 """
21632183 update_args = {"job_name" : estimator .latest_training_job .name }
21642184 update_args .update (build_dict ("profiler_rule_configs" , profiler_rule_configs ))
21652185 update_args .update (build_dict ("profiler_config" , profiler_config ))
2186+ update_args .update (build_dict ("resource_config" , resource_config ))
21662187
21672188 return update_args
21682189
@@ -2218,6 +2239,7 @@ def __init__(
22182239 role : str ,
22192240 instance_count : Optional [Union [int , PipelineVariable ]] = None ,
22202241 instance_type : Optional [Union [str , PipelineVariable ]] = None ,
2242+ keep_alive_period_in_seconds : Optional [Union [int , PipelineVariable ]] = None ,
22212243 volume_size : Union [int , PipelineVariable ] = 30 ,
22222244 volume_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
22232245 max_run : Union [int , PipelineVariable ] = 24 * 60 * 60 ,
@@ -2270,6 +2292,9 @@ def __init__(
22702292 instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
22712293 for example, ``'ml.c4.xlarge'``. Required if instance_groups is
22722294 not set.
2295+ keep_alive_period_in_seconds (int or PipelineVariable): The duration of time
2296+ in seconds to retain configured resources in a warm pool for subsequent
2297+ training jobs (default: None).
22732298 volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
22742299 storing input and output data during training (default: 30).
22752300
@@ -2591,6 +2616,7 @@ def __init__(
25912616 role ,
25922617 instance_count ,
25932618 instance_type ,
2619+ keep_alive_period_in_seconds ,
25942620 volume_size ,
25952621 volume_kms_key ,
25962622 max_run ,
0 commit comments