@@ -979,7 +979,7 @@ def from_recipe(
979979
980980 def with_tensorboard_output_config (
981981 self , tensorboard_output_config : TensorBoardOutputConfig
982- ) -> "ModelTrainer" :
982+ ) -> "ModelTrainer" : # noqa: D412
983983 """Set the TensorBoard output configuration.
984984
985985 Example:
@@ -998,15 +998,14 @@ def with_tensorboard_output_config(
998998 ...
999999 ).with_tensorboard_output_config(tensorboard_output_config)
10001000
1001-
10021001 Args:
10031002 tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
10041003 The TensorBoard output configuration.
10051004 """
10061005 self ._tensorboard_output_config = tensorboard_output_config
10071006 return self
10081007
1009- def with_retry_strategy (self , retry_strategy : RetryStrategy ) -> "ModelTrainer" :
1008+ def with_retry_strategy (self , retry_strategy : RetryStrategy ) -> "ModelTrainer" : # noqa: D412
10101009 """Set the retry strategy for the training job.
10111010
10121011 Example:
@@ -1031,7 +1030,9 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":
10311030 self ._retry_strategy = retry_strategy
10321031 return self
10331032
1034- def with_infra_check_config (self , infra_check_config : InfraCheckConfig ) -> "ModelTrainer" :
1033+ def with_infra_check_config (
1034+ self , infra_check_config : InfraCheckConfig
1035+ ) -> "ModelTrainer" : # noqa: D412
10351036 """Set the infra check configuration for the training job.
10361037
10371038 Example:
@@ -1058,7 +1059,7 @@ def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "Mode
10581059
10591060 def with_session_chaining_config (
10601061 self , session_chaining_config : SessionChainingConfig
1061- ) -> "ModelTrainer" :
1062+ ) -> "ModelTrainer" : # noqa: D412
10621063 """Set the session chaining configuration for the training job.
10631064
10641065 Example:
@@ -1083,17 +1084,35 @@ def with_session_chaining_config(
10831084 self ._session_chaining_config = session_chaining_config
10841085 return self
10851086
1086- def with_remote_debug_config (self , remote_debug_config : RemoteDebugConfig ) -> "ModelTrainer" :
1087+ def with_remote_debug_config (
1088+ self , remote_debug_config : RemoteDebugConfig
1089+ ) -> "ModelTrainer" : # noqa: D412
10871090 """Set the remote debug configuration for the training job.
10881091
1092+ Example:
1093+
1094+ .. code:: python
1095+
1096+ from sagemaker.modules.train import ModelTrainer
1097+ from sagemaker.modules.configs import RemoteDebugConfig
1098+
1099+ remote_debug_config = RemoteDebugConfig(
1100+ enable_remote_debug=True,
1101+ )
1102+ model_trainer = ModelTrainer(
1103+ ...
1104+ ).with_remote_debug_config(remote_debug_config)
1105+
10891106 Args:
10901107 remote_debug_config (RemoteDebugConfig):
10911108 The remote debug configuration for the training job.
10921109 """
10931110 self ._remote_debug_config = remote_debug_config
10941111 return self
10951112
1096- def with_metric_definitions (self , metric_definitions : List [MetricDefinition ]) -> "ModelTrainer" :
1113+ def with_metric_definitions (
1114+ self , metric_definitions : List [MetricDefinition ]
1115+ ) -> "ModelTrainer" : # noqa: D412
10971116 """Set the metric definitions for the training job.
10981117
10991118 Example:
@@ -1106,7 +1125,7 @@ def with_metric_definitions(self, metric_definitions: List[MetricDefinition]) ->
11061125 metric_definitions = [
11071126 MetricDefinition(
11081127 name="loss",
1109- regex="Loss: (.*?); ",
1128+ regex="Loss: (.*?)",
11101129 )
11111130 ]
11121131
0 commit comments