|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +from __future__ import absolute_import |
| 14 | + |
| 15 | +import pytest |
| 16 | +from datetime import datetime |
| 17 | +from mock import Mock, patch |
| 18 | + |
| 19 | +from sagemaker.local.entities import _LocalTrainingJob |
| 20 | + |
| 21 | + |
| 22 | +class TestLocalTrainingJobFinalMetrics: |
| 23 | + """Test cases for FinalMetricDataList functionality in _LocalTrainingJob.""" |
| 24 | + |
| 25 | + def test_describe_includes_final_metric_data_list(self): |
| 26 | + """Test that describe() includes FinalMetricDataList field.""" |
| 27 | + container = Mock() |
| 28 | + job = _LocalTrainingJob(container) |
| 29 | + job.training_job_name = "test-job" |
| 30 | + job.state = "Completed" |
| 31 | + job.start_time = datetime.now() |
| 32 | + job.end_time = datetime.now() |
| 33 | + job.model_artifacts = "/path/to/model" |
| 34 | + job.output_data_config = {} |
| 35 | + job.environment = {} |
| 36 | + |
| 37 | + response = job.describe() |
| 38 | + |
| 39 | + assert "FinalMetricDataList" in response |
| 40 | + assert isinstance(response["FinalMetricDataList"], list) |
| 41 | + |
| 42 | + def test_extract_final_metrics_no_logs(self): |
| 43 | + """Test _extract_final_metrics returns empty list when no logs.""" |
| 44 | + container = Mock() |
| 45 | + container.logs = None |
| 46 | + job = _LocalTrainingJob(container) |
| 47 | + |
| 48 | + result = job._extract_final_metrics() |
| 49 | + |
| 50 | + assert result == [] |
| 51 | + |
| 52 | + def test_extract_final_metrics_no_metric_definitions(self): |
| 53 | + """Test _extract_final_metrics returns empty list when no metric definitions.""" |
| 54 | + container = Mock() |
| 55 | + container.logs = "some logs" |
| 56 | + container.metric_definitions = [] |
| 57 | + job = _LocalTrainingJob(container) |
| 58 | + |
| 59 | + result = job._extract_final_metrics() |
| 60 | + |
| 61 | + assert result == [] |
| 62 | + |
| 63 | + def test_extract_final_metrics_with_valid_metrics(self): |
| 64 | + """Test _extract_final_metrics extracts metrics correctly.""" |
| 65 | + container = Mock() |
| 66 | + container.logs = "Training started\nGAN_loss=0.138318;\nTraining complete" |
| 67 | + container.metric_definitions = [ |
| 68 | + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} |
| 69 | + ] |
| 70 | + job = _LocalTrainingJob(container) |
| 71 | + job.end_time = datetime(2023, 1, 1, 12, 0, 0) |
| 72 | + |
| 73 | + result = job._extract_final_metrics() |
| 74 | + |
| 75 | + assert len(result) == 1 |
| 76 | + assert result[0]["MetricName"] == "ganloss" |
| 77 | + assert result[0]["Value"] == 0.138318 |
| 78 | + assert result[0]["Timestamp"] == job.end_time |
| 79 | + |
| 80 | + def test_extract_final_metrics_multiple_matches_uses_last(self): |
| 81 | + """Test _extract_final_metrics uses the last match for each metric.""" |
| 82 | + container = Mock() |
| 83 | + container.logs = "GAN_loss=0.5;\nGAN_loss=0.3;\nGAN_loss=0.138318;" |
| 84 | + container.metric_definitions = [ |
| 85 | + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} |
| 86 | + ] |
| 87 | + job = _LocalTrainingJob(container) |
| 88 | + job.end_time = datetime(2023, 1, 1, 12, 0, 0) |
| 89 | + |
| 90 | + result = job._extract_final_metrics() |
| 91 | + |
| 92 | + assert len(result) == 1 |
| 93 | + assert result[0]["Value"] == 0.138318 |
| 94 | + |
| 95 | + def test_extract_final_metrics_multiple_metrics(self): |
| 96 | + """Test _extract_final_metrics handles multiple different metrics.""" |
| 97 | + container = Mock() |
| 98 | + container.logs = "GAN_loss=0.138318;\nAccuracy=0.95;\nLoss=1.234;" |
| 99 | + container.metric_definitions = [ |
| 100 | + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}, |
| 101 | + {"Name": "accuracy", "Regex": r"Accuracy=([\d\.]+);"}, |
| 102 | + {"Name": "loss", "Regex": r"Loss=([\d\.]+);"} |
| 103 | + ] |
| 104 | + job = _LocalTrainingJob(container) |
| 105 | + job.end_time = datetime(2023, 1, 1, 12, 0, 0) |
| 106 | + |
| 107 | + result = job._extract_final_metrics() |
| 108 | + |
| 109 | + assert len(result) == 3 |
| 110 | + metric_names = [m["MetricName"] for m in result] |
| 111 | + assert "ganloss" in metric_names |
| 112 | + assert "accuracy" in metric_names |
| 113 | + assert "loss" in metric_names |
| 114 | + |
| 115 | + def test_extract_final_metrics_no_matches(self): |
| 116 | + """Test _extract_final_metrics returns empty list when regex doesn't match.""" |
| 117 | + container = Mock() |
| 118 | + container.logs = "Training started\nTraining complete" |
| 119 | + container.metric_definitions = [ |
| 120 | + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} |
| 121 | + ] |
| 122 | + job = _LocalTrainingJob(container) |
| 123 | + |
| 124 | + result = job._extract_final_metrics() |
| 125 | + |
| 126 | + assert result == [] |
| 127 | + |
| 128 | + def test_extract_final_metrics_invalid_metric_definition(self): |
| 129 | + """Test _extract_final_metrics skips invalid metric definitions.""" |
| 130 | + container = Mock() |
| 131 | + container.logs = "GAN_loss=0.138318;" |
| 132 | + container.metric_definitions = [ |
| 133 | + {"Name": "ganloss"}, # Missing Regex |
| 134 | + {"Regex": r"GAN_loss=([\d\.]+);"}, # Missing Name |
| 135 | + {"Name": "valid", "Regex": r"GAN_loss=([\d\.]+);"} # Valid |
| 136 | + ] |
| 137 | + job = _LocalTrainingJob(container) |
| 138 | + job.end_time = datetime(2023, 1, 1, 12, 0, 0) |
| 139 | + |
| 140 | + result = job._extract_final_metrics() |
| 141 | + |
| 142 | + assert len(result) == 1 |
| 143 | + assert result[0]["MetricName"] == "valid" |
| 144 | + |
| 145 | + @patch("sagemaker.local.entities.datetime") |
| 146 | + def test_extract_final_metrics_uses_current_time_when_no_end_time(self, mock_datetime): |
| 147 | + """Test _extract_final_metrics uses current time when end_time is None.""" |
| 148 | + container = Mock() |
| 149 | + container.logs = "GAN_loss=0.138318;" |
| 150 | + container.metric_definitions = [ |
| 151 | + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} |
| 152 | + ] |
| 153 | + job = _LocalTrainingJob(container) |
| 154 | + job.end_time = None |
| 155 | + |
| 156 | + mock_now = datetime(2023, 1, 1, 12, 0, 0) |
| 157 | + mock_datetime.now.return_value = mock_now |
| 158 | + |
| 159 | + result = job._extract_final_metrics() |
| 160 | + |
| 161 | + assert len(result) == 1 |
| 162 | + assert result[0]["Timestamp"] == mock_now |
| 163 | + |
| 164 | + @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") |
| 165 | + def test_integration_describe_training_job_with_metrics(self, mock_train): |
| 166 | + """Integration test: describe_training_job includes FinalMetricDataList.""" |
| 167 | + from sagemaker.local.local_session import LocalSagemakerClient |
| 168 | + |
| 169 | + local_sagemaker_client = LocalSagemakerClient() |
| 170 | + |
| 171 | + algo_spec = {"TrainingImage": "my-image:1.0"} |
| 172 | + input_data_config = [{ |
| 173 | + "ChannelName": "training", |
| 174 | + "DataSource": { |
| 175 | + "S3DataSource": { |
| 176 | + "S3DataDistributionType": "FullyReplicated", |
| 177 | + "S3Uri": "s3://bucket/data" |
| 178 | + } |
| 179 | + } |
| 180 | + }] |
| 181 | + output_data_config = {} |
| 182 | + resource_config = {"InstanceType": "local", "InstanceCount": 1} |
| 183 | + |
| 184 | + # Create training job |
| 185 | + local_sagemaker_client.create_training_job( |
| 186 | + "test-job", |
| 187 | + algo_spec, |
| 188 | + output_data_config, |
| 189 | + resource_config, |
| 190 | + InputDataConfig=input_data_config, |
| 191 | + HyperParameters={} |
| 192 | + ) |
| 193 | + |
| 194 | + # Mock the container logs and metric definitions |
| 195 | + training_job = local_sagemaker_client._training_jobs["test-job"] |
| 196 | + training_job.container.logs = "GAN_loss=0.138318;" |
| 197 | + training_job.container.metric_definitions = [ |
| 198 | + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} |
| 199 | + ] |
| 200 | + |
| 201 | + response = local_sagemaker_client.describe_training_job("test-job") |
| 202 | + |
| 203 | + assert "FinalMetricDataList" in response |
| 204 | + assert len(response["FinalMetricDataList"]) == 1 |
| 205 | + assert response["FinalMetricDataList"][0]["MetricName"] == "ganloss" |
| 206 | + assert response["FinalMetricDataList"][0]["Value"] == 0.138318 |
0 commit comments