diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 1ada10dff3..8fdf88e735 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -42,6 +42,7 @@ RemoteDebugConfig, SessionChainingConfig, InstanceGroup, + MetricDefinition, ) from sagemaker.modules.utils import convert_unassigned_to_none @@ -68,6 +69,7 @@ "Compute", "Networking", "InputData", + "MetricDefinition", ] diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 7d83766c9f..eaabe5972a 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -66,6 +66,7 @@ RemoteDebugConfig, SessionChainingConfig, InputData, + MetricDefinition, ) from sagemaker.modules.local_core.local_container import _LocalContainer @@ -239,6 +240,7 @@ class ModelTrainer(BaseModel): _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None) _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) + _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None) _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) @@ -696,6 +698,7 @@ def train( training_image_config=self.training_image_config, container_entrypoint=container_entrypoint, container_arguments=container_arguments, + metric_definitions=self._metric_definitions, ) resource_config = self.compute._to_resource_config() @@ -1290,3 +1293,33 @@ def with_checkpoint_config( """ self.checkpoint_config = checkpoint_config or configs.CheckpointConfig() return self + + def with_metric_definitions( + self, metric_definitions: List[MetricDefinition] + ) -> "ModelTrainer": # noqa: D412 + """Set the metric definitions for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import MetricDefinition + + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?)", + ) + ] + + model_trainer = ModelTrainer( + ... + ).with_metric_definitions(metric_definitions) + + Args: + metric_definitions (List[MetricDefinition]): + The metric definitions for the training job. + """ + self._metric_definitions = metric_definitions + return self diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index cf38f26334..23ea167ecf 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -64,6 +64,7 @@ FileSystemDataSource, Channel, DataSource, + MetricDefinition, ) from sagemaker.modules.distributed import Torchrun, SMP, MPI from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg @@ -705,6 +706,32 @@ def test_remote_debug_config(mock_training_job, modules_session): ) +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_metric_definitions(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?);", + ) + ] + + model_trainer = ModelTrainer( + training_image=image_uri, sagemaker_session=modules_session, role=role + ).with_metric_definitions(metric_definitions) + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions + == metric_definitions + ) + + @patch("sagemaker.modules.train.model_trainer._get_unique_name") @patch("sagemaker.modules.train.model_trainer.TrainingJob") def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session): @@ -822,6 +849,7 @@ def mock_upload_data(path, bucket, key_prefix): training_input_mode=training_input_mode, training_image=training_image, algorithm_name=None, + metric_definitions=None, container_entrypoint=DEFAULT_ENTRYPOINT, container_arguments=DEFAULT_ARGUMENTS, training_image_config=training_image_config,