Skip to content

Commit 20e58e8

Browse files
committed
feat: Add support for MetricDefinitions in ModelTrainer and update pydocs
1 parent 13ad978 commit 20e58e8

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed

src/sagemaker/modules/configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
InstanceGroup,
4646
TensorBoardOutputConfig,
4747
CheckpointConfig,
48+
MetricDefinition,
4849
)
4950

5051
from sagemaker.modules.utils import convert_unassigned_to_none
@@ -71,6 +72,7 @@
7172
"Compute",
7273
"Networking",
7374
"InputData",
75+
"MetricDefinition",
7476
]
7577

7678

src/sagemaker/modules/train/model_trainer.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
TensorBoardOutputConfig,
6868
CheckpointConfig,
6969
InputData,
70+
MetricDefinition,
7071
)
7172

7273
from sagemaker.modules.local_core.local_container import _LocalContainer
@@ -237,6 +238,7 @@ class ModelTrainer(BaseModel):
237238
_infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None)
238239
_session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None)
239240
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
241+
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
240242

241243
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
242244

@@ -587,6 +589,7 @@ def train(
587589
training_image_config=self.training_image_config,
588590
container_entrypoint=container_entrypoint,
589591
container_arguments=container_arguments,
592+
metric_definitions=self._metric_definitions,
590593
)
591594

592595
resource_config = self.compute._to_resource_config()
@@ -979,6 +982,23 @@ def with_tensorboard_output_config(
979982
) -> "ModelTrainer":
980983
"""Set the TensorBoard output configuration.
981984
985+
Example:
986+
987+
.. code:: python
988+
989+
from sagemaker.modules.train import ModelTrainer
990+
from sagemaker.modules.configs import TensorBoardOutputConfig
991+
992+
tensorboard_output_config = TensorBoardOutputConfig(
993+
s3_output_path="s3://bucket-name/tensorboard",
994+
local_path="/opt/ml/output/tensorboard"
995+
)
996+
997+
model_trainer = ModelTrainer(
998+
...
999+
).with_tensorboard_output_config(tensorboard_output_config)
1000+
1001+
9821002
Args:
9831003
tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
9841004
The TensorBoard output configuration.
@@ -989,6 +1009,21 @@ def with_tensorboard_output_config(
9891009
def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":
9901010
"""Set the retry strategy for the training job.
9911011
1012+
Example:
1013+
1014+
.. code:: python
1015+
1016+
from sagemaker.modules.train import ModelTrainer
1017+
from sagemaker.modules.configs import RetryStrategy
1018+
1019+
retry_strategy = RetryStrategy(
1020+
maximum_retry_attempts=3,
1021+
)
1022+
1023+
model_trainer = ModelTrainer(
1024+
...
1025+
).with_retry_strategy(retry_strategy)
1026+
9921027
Args:
9931028
retry_strategy (RetryStrategy):
9941029
The retry strategy for the training job.
@@ -999,6 +1034,21 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":
9991034
def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "ModelTrainer":
10001035
"""Set the infra check configuration for the training job.
10011036
1037+
Example:
1038+
1039+
.. code:: python
1040+
1041+
from sagemaker.modules.train import ModelTrainer
1042+
from sagemaker.modules.configs import InfraCheckConfig
1043+
1044+
infra_check_config = InfraCheckConfig(
1045+
enable_infra_check=True,
1046+
)
1047+
1048+
model_trainer = ModelTrainer(
1049+
...
1050+
).with_infra_check_config(infra_check_config)
1051+
10021052
Args:
10031053
infra_check_config (InfraCheckConfig):
10041054
The infra check configuration for the training job.
@@ -1011,6 +1061,21 @@ def with_session_chaining_config(
10111061
) -> "ModelTrainer":
10121062
"""Set the session chaining configuration for the training job.
10131063
1064+
Example:
1065+
1066+
.. code:: python
1067+
1068+
from sagemaker.modules.train import ModelTrainer
1069+
from sagemaker.modules.configs import SessionChainingConfig
1070+
1071+
session_chaining_config = SessionChainingConfig(
1072+
enable_session_tag_chaining=True,
1073+
)
1074+
1075+
model_trainer = ModelTrainer(
1076+
...
1077+
).with_session_chaining_config(session_chaining_config)
1078+
10141079
Args:
10151080
session_chaining_config (SessionChainingConfig):
10161081
The session chaining configuration for the training job.
@@ -1027,3 +1092,31 @@ def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "M
10271092
"""
10281093
self._remote_debug_config = remote_debug_config
10291094
return self
1095+
1096+
def with_metric_definitions(self, metric_definitions: List[MetricDefinition]) -> "ModelTrainer":
1097+
"""Set the metric definitions for the training job.
1098+
1099+
Example:
1100+
1101+
.. code:: python
1102+
1103+
from sagemaker.modules.train import ModelTrainer
1104+
from sagemaker.modules.configs import MetricDefinition
1105+
1106+
metric_definitions = [
1107+
MetricDefinition(
1108+
name="loss",
1109+
regex="Loss: (.*?);",
1110+
)
1111+
]
1112+
1113+
model_trainer = ModelTrainer(
1114+
...
1115+
).with_metric_definitions(metric_definitions)
1116+
1117+
Args:
1118+
metric_definitions (List[MetricDefinition]):
1119+
The metric definitions for the training job.
1120+
"""
1121+
self._metric_definitions = metric_definitions
1122+
return self

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
FileSystemDataSource,
6363
Channel,
6464
DataSource,
65+
MetricDefinition,
6566
)
6667
from sagemaker.modules.distributed import Torchrun, SMP, MPI
6768
from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg
@@ -654,6 +655,32 @@ def test_remote_debug_config(mock_training_job, modules_session):
654655
)
655656

656657

658+
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
659+
def test_metric_definitions(mock_training_job, modules_session):
660+
image_uri = DEFAULT_IMAGE
661+
role = DEFAULT_ROLE
662+
metric_definitions = [
663+
MetricDefinition(
664+
name="loss",
665+
regex="Loss: (.*?);",
666+
)
667+
]
668+
669+
model_trainer = ModelTrainer(
670+
training_image=image_uri, sagemaker_session=modules_session, role=role
671+
).with_metric_definitions(metric_definitions)
672+
673+
with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data:
674+
mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix"
675+
model_trainer.train()
676+
677+
mock_training_job.create.assert_called_once()
678+
assert (
679+
mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions
680+
== metric_definitions
681+
)
682+
683+
657684
@patch("sagemaker.modules.train.model_trainer._get_unique_name")
658685
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
659686
def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session):
@@ -771,6 +798,7 @@ def mock_upload_data(path, bucket, key_prefix):
771798
training_input_mode=training_input_mode,
772799
training_image=training_image,
773800
algorithm_name=None,
801+
metric_definitions=None,
774802
container_entrypoint=DEFAULT_ENTRYPOINT,
775803
container_arguments=DEFAULT_ARGUMENTS,
776804
training_image_config=training_image_config,

0 commit comments

Comments
 (0)