@@ -185,6 +185,7 @@ def __init__(
185185 disable_output_compression : bool = False ,
186186 enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
187187 enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
188+ training_plan : Optional [Union [str , PipelineVariable ]] = None ,
188189 ** kwargs ,
189190 ):
190191 """Initialize an ``EstimatorBase`` instance.
@@ -554,6 +555,8 @@ def __init__(
554555 Specifies whether RemoteDebug is enabled for the training job.
555556 enable_session_tag_chaining (bool or PipelineVariable): Optional.
556557 Specifies whether SessionTagChaining is enabled for the training job.
558+ training_plan (str or PipelineVariable): Optional.
559+ Specifies which training plan arn to use for the training job
557560 """
558561 instance_count = renamed_kwargs (
559562 "train_instance_count" , "instance_count" , instance_count , kwargs
@@ -762,8 +765,7 @@ def __init__(
762765
763766 self .tensorboard_output_config = tensorboard_output_config
764767
765- self .debugger_rule_configs = None
766- self .collection_configs = None
768+ self .debugger_rule_configs , self .collection_configs = None , None
767769
768770 self .enable_sagemaker_metrics = enable_sagemaker_metrics
769771
@@ -774,6 +776,7 @@ def __init__(
774776 sagemaker_session = self .sagemaker_session ,
775777 )
776778
779+ self .profiler_rule_configs , self .profiler_rules = None , None
777780 self .profiler_config = profiler_config
778781 self .disable_profiler = resolve_value_from_config (
779782 direct_input = disable_profiler ,
@@ -796,8 +799,6 @@ def __init__(
796799 ) or _instance_type_supports_profiler (self .instance_type ):
797800 self .disable_profiler = True
798801
799- self .profiler_rule_configs = None
800- self .profiler_rules = None
801802 self .debugger_rules = None
802803 self .disable_output_compression = disable_output_compression
803804 validate_source_code_input_against_pipeline_variables (
@@ -807,6 +808,8 @@ def __init__(
807808 enable_network_isolation = self ._enable_network_isolation ,
808809 )
809810
811+ self .training_plan = training_plan
812+
810813 # Internal flag
811814 self ._is_output_path_set_from_default_bucket_and_prefix = False
812815
@@ -1960,6 +1963,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19601963 "KeepAlivePeriodInSeconds"
19611964 ]
19621965
1966+ if "TrainingPlanArn" in job_details ["ResourceConfig" ]:
1967+ init_params ["training_plan" ] = job_details ["ResourceConfig" ]["TrainingPlanArn" ]
1968+
19631969 has_hps = "HyperParameters" in job_details
19641970 init_params ["hyperparameters" ] = job_details ["HyperParameters" ] if has_hps else {}
19651971
@@ -2840,6 +2846,7 @@ def __init__(
28402846 enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
28412847 enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
28422848 enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
2849+ training_plan : Optional [Union [str , PipelineVariable ]] = None ,
28432850 ** kwargs ,
28442851 ):
28452852 """Initialize an ``Estimator`` instance.
@@ -3205,6 +3212,8 @@ def __init__(
32053212 Specifies whether RemoteDebug is enabled for the training job
32063213 enable_session_tag_chaining (bool or PipelineVariable): Optional.
32073214 Specifies whether SessionTagChaining is enabled for the training job
3215+ training_plan (str or PipelineVariable): Optional.
3216+ Specifies which training plan arn to use for the training job
32083217 """
32093218 self .image_uri = image_uri
32103219 self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -3258,6 +3267,7 @@ def __init__(
32583267 disable_output_compression = disable_output_compression ,
32593268 enable_remote_debug = enable_remote_debug ,
32603269 enable_session_tag_chaining = enable_session_tag_chaining ,
3270+ training_plan = training_plan ,
32613271 ** kwargs ,
32623272 )
32633273
0 commit comments