@@ -979,7 +979,7 @@ def from_recipe(
979
979
980
980
def with_tensorboard_output_config (
981
981
self , tensorboard_output_config : TensorBoardOutputConfig
982
- ) -> "ModelTrainer" :
982
+ ) -> "ModelTrainer" : # noqa: D412
983
983
"""Set the TensorBoard output configuration.
984
984
985
985
Example:
@@ -998,15 +998,14 @@ def with_tensorboard_output_config(
998
998
...
999
999
).with_tensorboard_output_config(tensorboard_output_config)
1000
1000
1001
-
1002
1001
Args:
1003
1002
tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
1004
1003
The TensorBoard output configuration.
1005
1004
"""
1006
1005
self ._tensorboard_output_config = tensorboard_output_config
1007
1006
return self
1008
1007
1009
- def with_retry_strategy (self , retry_strategy : RetryStrategy ) -> "ModelTrainer" :
1008
+ def with_retry_strategy (self , retry_strategy : RetryStrategy ) -> "ModelTrainer" : # noqa: D412
1010
1009
"""Set the retry strategy for the training job.
1011
1010
1012
1011
Example:
@@ -1031,7 +1030,9 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":
1031
1030
self ._retry_strategy = retry_strategy
1032
1031
return self
1033
1032
1034
- def with_infra_check_config (self , infra_check_config : InfraCheckConfig ) -> "ModelTrainer" :
1033
+ def with_infra_check_config (
1034
+ self , infra_check_config : InfraCheckConfig
1035
+ ) -> "ModelTrainer" : # noqa: D412
1035
1036
"""Set the infra check configuration for the training job.
1036
1037
1037
1038
Example:
@@ -1058,7 +1059,7 @@ def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "Mode
1058
1059
1059
1060
def with_session_chaining_config (
1060
1061
self , session_chaining_config : SessionChainingConfig
1061
- ) -> "ModelTrainer" :
1062
+ ) -> "ModelTrainer" : # noqa: D412
1062
1063
"""Set the session chaining configuration for the training job.
1063
1064
1064
1065
Example:
@@ -1083,17 +1084,35 @@ def with_session_chaining_config(
1083
1084
self ._session_chaining_config = session_chaining_config
1084
1085
return self
1085
1086
1086
- def with_remote_debug_config (self , remote_debug_config : RemoteDebugConfig ) -> "ModelTrainer" :
1087
+ def with_remote_debug_config (
1088
+ self , remote_debug_config : RemoteDebugConfig
1089
+ ) -> "ModelTrainer" : # noqa: D412
1087
1090
"""Set the remote debug configuration for the training job.
1088
1091
1092
+ Example:
1093
+
1094
+ .. code:: python
1095
+
1096
+ from sagemaker.modules.train import ModelTrainer
1097
+ from sagemaker.modules.configs import RemoteDebugConfig
1098
+
1099
+ remote_debug_config = RemoteDebugConfig(
1100
+ enable_remote_debug=True,
1101
+ )
1102
+ model_trainer = ModelTrainer(
1103
+ ...
1104
+ ).with_remote_debug_config(remote_debug_config)
1105
+
1089
1106
Args:
1090
1107
remote_debug_config (RemoteDebugConfig):
1091
1108
The remote debug configuration for the training job.
1092
1109
"""
1093
1110
self ._remote_debug_config = remote_debug_config
1094
1111
return self
1095
1112
1096
- def with_metric_definitions (self , metric_definitions : List [MetricDefinition ]) -> "ModelTrainer" :
1113
+ def with_metric_definitions (
1114
+ self , metric_definitions : List [MetricDefinition ]
1115
+ ) -> "ModelTrainer" : # noqa: D412
1097
1116
"""Set the metric definitions for the training job.
1098
1117
1099
1118
Example:
@@ -1106,7 +1125,7 @@ def with_metric_definitions(self, metric_definitions: List[MetricDefinition]) ->
1106
1125
metric_definitions = [
1107
1126
MetricDefinition(
1108
1127
name="loss",
1109
- regex="Loss: (.*?); ",
1128
+ regex="Loss: (.*?)",
1110
1129
)
1111
1130
]
1112
1131
0 commit comments