Skip to content

Commit b5714a8

Browse files
author
Chad Chiang
committed
feat: Add support for MetricDefinitions in ModelTrainer
1 parent 829030a commit b5714a8

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

src/sagemaker/modules/configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
RemoteDebugConfig,
4343
SessionChainingConfig,
4444
InstanceGroup,
45+
MetricDefinition,
4546
)
4647

4748
from sagemaker.modules.utils import convert_unassigned_to_none
@@ -68,6 +69,7 @@
6869
"Compute",
6970
"Networking",
7071
"InputData",
72+
"MetricDefinition",
7173
]
7274

7375

src/sagemaker/modules/train/model_trainer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
RemoteDebugConfig,
6767
SessionChainingConfig,
6868
InputData,
69+
MetricDefinition,
6970
)
7071

7172
from sagemaker.modules.local_core.local_container import _LocalContainer
@@ -1290,3 +1291,27 @@ def with_checkpoint_config(
12901291
"""
12911292
self.checkpoint_config = checkpoint_config or configs.CheckpointConfig()
12921293
return self
1294+
1295+
def with_metric_definitions(
1296+
self, metric_definitions: List[MetricDefinition]
1297+
) -> "ModelTrainer": # noqa: D412
1298+
"""Set the metric definitions for the training job.
1299+
Example:
1300+
.. code:: python
1301+
from sagemaker.modules.train import ModelTrainer
1302+
from sagemaker.modules.configs import MetricDefinition
1303+
metric_definitions = [
1304+
MetricDefinition(
1305+
name="loss",
1306+
regex="Loss: (.*?);",
1307+
)
1308+
]
1309+
model_trainer = ModelTrainer(
1310+
...
1311+
).with_metric_definitions(metric_definitions)
1312+
Args:
1313+
metric_definitions (List[MetricDefinition]):
1314+
The metric definitions for the training job; each entry pairs a metric name with the regex used to capture its value.
1315+
"""
1316+
self._metric_definitions = metric_definitions
1317+
return self

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
FileSystemDataSource,
6565
Channel,
6666
DataSource,
67+
MetricDefinition,
6768
)
6869
from sagemaker.modules.distributed import Torchrun, SMP, MPI
6970
from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg
@@ -705,6 +706,32 @@ def test_remote_debug_config(mock_training_job, modules_session):
705706
)
706707

707708

709+
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
710+
def test_metric_definitions(mock_training_job, modules_session):
711+
image_uri = DEFAULT_IMAGE
712+
role = DEFAULT_ROLE
713+
metric_definitions = [
714+
MetricDefinition(
715+
name="loss",
716+
regex="Loss: (.*?);",
717+
)
718+
]
719+
720+
model_trainer = ModelTrainer(
721+
training_image=image_uri, sagemaker_session=modules_session, role=role
722+
).with_metric_definitions(metric_definitions)
723+
724+
with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data:
725+
mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix"
726+
model_trainer.train()
727+
728+
mock_training_job.create.assert_called_once()
729+
assert (
730+
mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions
731+
== metric_definitions
732+
)
733+
734+
708735
@patch("sagemaker.modules.train.model_trainer._get_unique_name")
709736
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
710737
def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session):

0 commit comments

Comments
 (0)