@@ -181,6 +181,7 @@ def __init__(
181181 container_arguments : Optional [List [str ]] = None ,
182182 disable_output_compression : bool = False ,
183183 enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
184+ enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
184185 ** kwargs ,
185186 ):
186187 """Initialize an ``EstimatorBase`` instance.
@@ -544,7 +545,9 @@ def __init__(
544545 enable_infra_check (bool or PipelineVariable): Optional.
545546 Specifies whether it is running Sagemaker built-in infra check jobs.
546547 enable_remote_debug (bool or PipelineVariable): Optional.
547- Specifies whether RemoteDebug is enabled for the training job
548+ Specifies whether RemoteDebug is enabled for the training job.
549+ enable_session_tag_chaining (bool or PipelineVariable): Optional.
550+ Specifies whether SessionTagChaining is enabled for the training job.
548551 """
549552 instance_count = renamed_kwargs (
550553 "train_instance_count" , "instance_count" , instance_count , kwargs
@@ -785,6 +788,8 @@ def __init__(
785788
786789 self ._enable_remote_debug = enable_remote_debug
787790
791+ self ._enable_session_tag_chaining = enable_session_tag_chaining
792+
788793 @abstractmethod
789794 def training_image_uri (self ):
790795 """Return the Docker image to use for training.
@@ -2318,6 +2323,14 @@ def get_remote_debug_config(self):
23182323 else {"EnableRemoteDebug" : self ._enable_remote_debug }
23192324 )
23202325
2326+ def get_session_chaining_config (self ):
2327+ """dict: Return the configuration of SessionChaining"""
2328+ return (
2329+ None
2330+ if self ._enable_session_tag_chaining is None
2331+ else {"EnableSessionTagChaining" : self ._enable_session_tag_chaining }
2332+ )
2333+
23212334 def enable_remote_debug (self ):
23222335 """Enable remote debug for a training job."""
23232336 self ._update_remote_debug (True )
@@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25742587 if estimator .get_remote_debug_config () is not None :
25752588 train_args ["remote_debug_config" ] = estimator .get_remote_debug_config ()
25762589
2590+ if estimator .get_session_chaining_config () is not None :
2591+ train_args ["session_chaining_config" ] = estimator .get_session_chaining_config ()
2592+
25772593 return train_args
25782594
25792595 @classmethod
@@ -2766,6 +2782,7 @@ def __init__(
27662782 disable_output_compression : bool = False ,
27672783 enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
27682784 enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
2785+ enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
27692786 ** kwargs ,
27702787 ):
27712788 """Initialize an ``Estimator`` instance.
@@ -3129,6 +3146,8 @@ def __init__(
31293146 Specifies whether it is running Sagemaker built-in infra check jobs.
31303147 enable_remote_debug (bool or PipelineVariable): Optional.
31313148 Specifies whether RemoteDebug is enabled for the training job
3149+ enable_session_tag_chaining (bool or PipelineVariable): Optional.
3150+ Specifies whether SessionTagChaining is enabled for the training job
31323151 """
31333152 self .image_uri = image_uri
31343153 self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -3181,6 +3200,7 @@ def __init__(
31813200 container_arguments = container_arguments ,
31823201 disable_output_compression = disable_output_compression ,
31833202 enable_remote_debug = enable_remote_debug ,
3203+ enable_session_tag_chaining = enable_session_tag_chaining ,
31843204 ** kwargs ,
31853205 )
31863206
0 commit comments