67
67
TensorBoardOutputConfig ,
68
68
CheckpointConfig ,
69
69
InputData ,
70
+ MetricDefinition ,
70
71
)
71
72
72
73
from sagemaker .modules .local_core .local_container import _LocalContainer
@@ -237,6 +238,7 @@ class ModelTrainer(BaseModel):
237
238
_infra_check_config : Optional [InfraCheckConfig ] = PrivateAttr (default = None )
238
239
_session_chaining_config : Optional [SessionChainingConfig ] = PrivateAttr (default = None )
239
240
_remote_debug_config : Optional [RemoteDebugConfig ] = PrivateAttr (default = None )
241
+ _metric_definitions : Optional [List [MetricDefinition ]] = PrivateAttr (default = None )
240
242
241
243
_temp_recipe_train_dir : Optional [TemporaryDirectory ] = PrivateAttr (default = None )
242
244
@@ -587,6 +589,7 @@ def train(
587
589
training_image_config = self .training_image_config ,
588
590
container_entrypoint = container_entrypoint ,
589
591
container_arguments = container_arguments ,
592
+ metric_definitions = self ._metric_definitions ,
590
593
)
591
594
592
595
resource_config = self .compute ._to_resource_config ()
@@ -979,6 +982,23 @@ def with_tensorboard_output_config(
979
982
) -> "ModelTrainer" :
980
983
"""Set the TensorBoard output configuration.
981
984
985
+ Example:
986
+
987
+ .. code:: python
988
+
989
+ from sagemaker.modules.train import ModelTrainer
990
+ from sagemaker.modules.configs import TensorBoardOutputConfig
991
+
992
+ tensorboard_output_config = TensorBoardOutputConfig(
993
+ s3_output_path="s3://bucket-name/tensorboard",
994
+ local_path="/opt/ml/output/tensorboard"
995
+ )
996
+
997
+ model_trainer = ModelTrainer(
998
+ ...
999
+ ).with_tensorboard_output_config(tensorboard_output_config)
1000
+
1001
+
982
1002
Args:
983
1003
tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
984
1004
The TensorBoard output configuration.
@@ -989,6 +1009,21 @@ def with_tensorboard_output_config(
989
1009
def with_retry_strategy (self , retry_strategy : RetryStrategy ) -> "ModelTrainer" :
990
1010
"""Set the retry strategy for the training job.
991
1011
1012
+ Example:
1013
+
1014
+ .. code:: python
1015
+
1016
+ from sagemaker.modules.train import ModelTrainer
1017
+ from sagemaker.modules.configs import RetryStrategy
1018
+
1019
+ retry_strategy = RetryStrategy(
1020
+ maximum_retry_attempts=3,
1021
+ )
1022
+
1023
+ model_trainer = ModelTrainer(
1024
+ ...
1025
+ ).with_retry_strategy(retry_strategy)
1026
+
992
1027
Args:
993
1028
retry_strategy (RetryStrategy):
994
1029
The retry strategy for the training job.
@@ -999,6 +1034,21 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":
999
1034
def with_infra_check_config (self , infra_check_config : InfraCheckConfig ) -> "ModelTrainer" :
1000
1035
"""Set the infra check configuration for the training job.
1001
1036
1037
+ Example:
1038
+
1039
+ .. code:: python
1040
+
1041
+ from sagemaker.modules.train import ModelTrainer
1042
+ from sagemaker.modules.configs import InfraCheckConfig
1043
+
1044
+ infra_check_config = InfraCheckConfig(
1045
+ enable_infra_check=True,
1046
+ )
1047
+
1048
+ model_trainer = ModelTrainer(
1049
+ ...
1050
+ ).with_infra_check_config(infra_check_config)
1051
+
1002
1052
Args:
1003
1053
infra_check_config (InfraCheckConfig):
1004
1054
The infra check configuration for the training job.
@@ -1011,6 +1061,21 @@ def with_session_chaining_config(
1011
1061
) -> "ModelTrainer" :
1012
1062
"""Set the session chaining configuration for the training job.
1013
1063
1064
+ Example:
1065
+
1066
+ .. code:: python
1067
+
1068
+ from sagemaker.modules.train import ModelTrainer
1069
+ from sagemaker.modules.configs import SessionChainingConfig
1070
+
1071
+ session_chaining_config = SessionChainingConfig(
1072
+ enable_session_tag_chaining=True,
1073
+ )
1074
+
1075
+ model_trainer = ModelTrainer(
1076
+ ...
1077
+ ).with_session_chaining_config(session_chaining_config
1078
+
1014
1079
Args:
1015
1080
session_chaining_config (SessionChainingConfig):
1016
1081
The session chaining configuration for the training job.
@@ -1027,3 +1092,31 @@ def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "M
1027
1092
"""
1028
1093
self ._remote_debug_config = remote_debug_config
1029
1094
return self
1095
+
1096
+ def with_metric_definitions (self , metric_definitions : List [MetricDefinition ]) -> "ModelTrainer" :
1097
+ """Set the metric definitions for the training job.
1098
+
1099
+ Example:
1100
+
1101
+ .. code:: python
1102
+
1103
+ from sagemaker.modules.train import ModelTrainer
1104
+ from sagemaker.modules.configs import MetricDefinition
1105
+
1106
+ metric_definitions = [
1107
+ MetricDefinition(
1108
+ name="loss",
1109
+ regex="Loss: (.*?);",
1110
+ )
1111
+ ]
1112
+
1113
+ model_trainer = ModelTrainer(
1114
+ ...
1115
+ ).with_metric_definitions(metric_definitions)
1116
+
1117
+ Args:
1118
+ metric_definitions (List[MetricDefinition]):
1119
+ The metric definitions for the training job.
1120
+ """
1121
+ self ._metric_definitions = metric_definitions
1122
+ return self
0 commit comments