|
16 | 16 | import pytest |
17 | 17 | import urllib3 |
18 | 18 | import os |
| 19 | +from datetime import datetime |
19 | 20 | from botocore.exceptions import ClientError |
20 | 21 | from mock import Mock, patch |
21 | 22 | from tests.unit import DATA_DIR, SAGEMAKER_CONFIG_SESSION |
|
25 | 26 | from sagemaker.workflow.pipeline import Pipeline |
26 | 27 | from tests.unit.sagemaker.workflow.helpers import CustomStep |
27 | 28 | from sagemaker.local.local_session import LocalSession |
28 | | -from sagemaker.local.entities import _LocalPipelineExecution |
| 29 | +from sagemaker.local.entities import _LocalPipelineExecution, _LocalTrainingJob |
29 | 30 |
|
30 | 31 |
|
31 | 32 | OK_RESPONSE = urllib3.HTTPResponse() |
@@ -1100,3 +1101,190 @@ def test_config_setter(): |
1100 | 1101 |
|
1101 | 1102 | with pytest.raises(jsonschema.ValidationError): |
1102 | 1103 | session.config = INVALID_LOCAL_MODE_CONFIG |
| 1104 | + |
| 1105 | + |
class TestLocalTrainingJobFinalMetrics:
    """Exercise the ``FinalMetricDataList`` support of ``_LocalTrainingJob``.

    The unit tests drive the job through a ``Mock`` container whose ``logs``
    and ``metric_definitions`` attributes are assigned by hand; the final test
    goes through ``LocalSagemakerClient`` instead.
    """

    # Pattern shared by every test that scrapes the GAN loss out of the logs.
    _GAN_LOSS_REGEX = r"GAN_loss=([\d\.]+);"
    # Fixed, arbitrary timestamp used as the job end time / mocked "now".
    _FIXED_TIME = datetime(2023, 1, 1, 12, 0, 0)

    @staticmethod
    def _build_job(logs, metric_definitions=None):
        """Return a ``_LocalTrainingJob`` wrapping a freshly mocked container.

        ``metric_definitions`` is only assigned when it is not ``None``;
        otherwise the attribute is left as whatever ``Mock`` auto-creates.
        """
        fake_container = Mock()
        fake_container.logs = logs
        if metric_definitions is not None:
            fake_container.metric_definitions = metric_definitions
        return _LocalTrainingJob(fake_container)

    @classmethod
    def _gan_loss_definition(cls, name="ganloss"):
        """Return a new metric definition that matches ``GAN_loss=<number>;``."""
        return {"Name": name, "Regex": cls._GAN_LOSS_REGEX}

    def test_describe_includes_final_metric_data_list(self):
        """describe() must expose a list under the FinalMetricDataList key."""
        job = self._build_job(None, [])
        job.training_job_name = "test-job"
        job.state = "Completed"
        job.start_time = datetime.now()
        job.end_time = datetime.now()
        job.model_artifacts = "/path/to/model"
        job.output_data_config = {}
        job.environment = {}

        description = job.describe()

        assert "FinalMetricDataList" in description
        assert isinstance(description["FinalMetricDataList"], list)

    def test_extract_final_metrics_no_logs(self):
        """A container whose logs are None yields no metrics."""
        job = self._build_job(None)

        assert job._extract_final_metrics() == []

    def test_extract_final_metrics_no_metric_definitions(self):
        """Logs without any metric definitions yield no metrics."""
        job = self._build_job("some logs", [])

        assert job._extract_final_metrics() == []

    def test_extract_final_metrics_with_valid_metrics(self):
        """A single matching log line produces one fully populated entry."""
        job = self._build_job(
            "Training started\nGAN_loss=0.138318;\nTraining complete",
            [self._gan_loss_definition()],
        )
        job.end_time = self._FIXED_TIME

        extracted = job._extract_final_metrics()

        assert len(extracted) == 1
        metric = extracted[0]
        assert metric["MetricName"] == "ganloss"
        assert metric["Value"] == 0.138318
        assert metric["Timestamp"] == job.end_time

    def test_extract_final_metrics_multiple_matches_uses_last(self):
        """When a regex matches several times, the last value is reported."""
        job = self._build_job(
            "GAN_loss=0.5;\nGAN_loss=0.3;\nGAN_loss=0.138318;",
            [self._gan_loss_definition()],
        )
        job.end_time = self._FIXED_TIME

        extracted = job._extract_final_metrics()

        assert len(extracted) == 1
        assert extracted[0]["Value"] == 0.138318

    def test_extract_final_metrics_multiple_metrics(self):
        """Each of several distinct metric definitions gets its own entry."""
        definitions = [
            self._gan_loss_definition(),
            {"Name": "accuracy", "Regex": r"Accuracy=([\d\.]+);"},
            {"Name": "loss", "Regex": r"Loss=([\d\.]+);"},
        ]
        job = self._build_job("GAN_loss=0.138318;\nAccuracy=0.95;\nLoss=1.234;", definitions)
        job.end_time = self._FIXED_TIME

        extracted = job._extract_final_metrics()

        assert len(extracted) == 3
        reported_names = [entry["MetricName"] for entry in extracted]
        for expected_name in ("ganloss", "accuracy", "loss"):
            assert expected_name in reported_names

    def test_extract_final_metrics_no_matches(self):
        """A regex that never matches the logs yields no metrics."""
        job = self._build_job(
            "Training started\nTraining complete",
            [self._gan_loss_definition()],
        )

        assert job._extract_final_metrics() == []

    def test_extract_final_metrics_invalid_metric_definition(self):
        """Definitions lacking Name or Regex are skipped; valid ones still count."""
        definitions = [
            {"Name": "ganloss"},  # no "Regex" key
            {"Regex": self._GAN_LOSS_REGEX},  # no "Name" key
            self._gan_loss_definition("valid"),  # complete definition
        ]
        job = self._build_job("GAN_loss=0.138318;", definitions)
        job.end_time = self._FIXED_TIME

        extracted = job._extract_final_metrics()

        assert len(extracted) == 1
        assert extracted[0]["MetricName"] == "valid"

    @patch("sagemaker.local.entities.datetime")
    def test_extract_final_metrics_uses_current_time_when_no_end_time(self, mock_datetime):
        """With end_time unset, the timestamp comes from the patched ``now()``."""
        frozen_now = self._FIXED_TIME
        mock_datetime.now.return_value = frozen_now

        job = self._build_job("GAN_loss=0.138318;", [self._gan_loss_definition()])
        job.end_time = None

        extracted = job._extract_final_metrics()

        assert len(extracted) == 1
        assert extracted[0]["Timestamp"] == frozen_now

    @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
    def test_integration_describe_training_job_with_metrics(self, mock_train):
        """describe_training_job on the local client reports the final metrics."""
        job_name = "test-job"
        client = sagemaker.local.local_session.LocalSagemakerClient()

        training_channel = {
            "ChannelName": "training",
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3Uri": "s3://bucket/data",
                }
            },
        }

        # Register the job with the client; container training is patched out above.
        client.create_training_job(
            job_name,
            {"TrainingImage": "my-image:1.0"},
            {},
            {"InstanceType": "local", "InstanceCount": 1},
            InputDataConfig=[training_channel],
            HyperParameters={},
        )

        # Inject the logs and metric definitions the job will be described from.
        job_container = client._training_jobs[job_name].container
        job_container.logs = "GAN_loss=0.138318;"
        job_container.metric_definitions = [self._gan_loss_definition()]

        description = client.describe_training_job(job_name)

        assert "FinalMetricDataList" in description
        final_metrics = description["FinalMetricDataList"]
        assert len(final_metrics) == 1
        assert final_metrics[0]["MetricName"] == "ganloss"
        assert final_metrics[0]["Value"] == 0.138318
0 commit comments