Skip to content

Commit 5d15929

Browse files
committed
chore: add more tests
1 parent fc756f5 commit 5d15929

File tree

1 file changed

+189
-1
lines changed

1 file changed

+189
-1
lines changed

tests/unit/sagemaker/local/test_local_session.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717
import urllib3
1818
import os
19+
from datetime import datetime
1920
from botocore.exceptions import ClientError
2021
from mock import Mock, patch
2122
from tests.unit import DATA_DIR, SAGEMAKER_CONFIG_SESSION
@@ -25,7 +26,7 @@
2526
from sagemaker.workflow.pipeline import Pipeline
2627
from tests.unit.sagemaker.workflow.helpers import CustomStep
2728
from sagemaker.local.local_session import LocalSession
28-
from sagemaker.local.entities import _LocalPipelineExecution
29+
from sagemaker.local.entities import _LocalPipelineExecution, _LocalTrainingJob
2930

3031

3132
OK_RESPONSE = urllib3.HTTPResponse()
@@ -1100,3 +1101,190 @@ def test_config_setter():
11001101

11011102
with pytest.raises(jsonschema.ValidationError):
11021103
session.config = INVALID_LOCAL_MODE_CONFIG
1104+
1105+
1106+
class TestLocalTrainingJobFinalMetrics:
1107+
"""Test cases for FinalMetricDataList functionality in _LocalTrainingJob."""
1108+
1109+
def test_describe_includes_final_metric_data_list(self):
1110+
"""Test that describe() includes FinalMetricDataList field."""
1111+
container = Mock()
1112+
container.logs = None
1113+
container.metric_definitions = []
1114+
job = _LocalTrainingJob(container)
1115+
job.training_job_name = "test-job"
1116+
job.state = "Completed"
1117+
job.start_time = datetime.now()
1118+
job.end_time = datetime.now()
1119+
job.model_artifacts = "/path/to/model"
1120+
job.output_data_config = {}
1121+
job.environment = {}
1122+
1123+
response = job.describe()
1124+
1125+
assert "FinalMetricDataList" in response
1126+
assert isinstance(response["FinalMetricDataList"], list)
1127+
1128+
def test_extract_final_metrics_no_logs(self):
1129+
"""Test _extract_final_metrics returns empty list when no logs."""
1130+
container = Mock()
1131+
container.logs = None
1132+
job = _LocalTrainingJob(container)
1133+
1134+
result = job._extract_final_metrics()
1135+
1136+
assert result == []
1137+
1138+
def test_extract_final_metrics_no_metric_definitions(self):
1139+
"""Test _extract_final_metrics returns empty list when no metric definitions."""
1140+
container = Mock()
1141+
container.logs = "some logs"
1142+
container.metric_definitions = []
1143+
job = _LocalTrainingJob(container)
1144+
1145+
result = job._extract_final_metrics()
1146+
1147+
assert result == []
1148+
1149+
def test_extract_final_metrics_with_valid_metrics(self):
1150+
"""Test _extract_final_metrics extracts metrics correctly."""
1151+
container = Mock()
1152+
container.logs = "Training started\nGAN_loss=0.138318;\nTraining complete"
1153+
container.metric_definitions = [
1154+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
1155+
]
1156+
job = _LocalTrainingJob(container)
1157+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
1158+
1159+
result = job._extract_final_metrics()
1160+
1161+
assert len(result) == 1
1162+
assert result[0]["MetricName"] == "ganloss"
1163+
assert result[0]["Value"] == 0.138318
1164+
assert result[0]["Timestamp"] == job.end_time
1165+
1166+
def test_extract_final_metrics_multiple_matches_uses_last(self):
1167+
"""Test _extract_final_metrics uses the last match for each metric."""
1168+
container = Mock()
1169+
container.logs = "GAN_loss=0.5;\nGAN_loss=0.3;\nGAN_loss=0.138318;"
1170+
container.metric_definitions = [
1171+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
1172+
]
1173+
job = _LocalTrainingJob(container)
1174+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
1175+
1176+
result = job._extract_final_metrics()
1177+
1178+
assert len(result) == 1
1179+
assert result[0]["Value"] == 0.138318
1180+
1181+
def test_extract_final_metrics_multiple_metrics(self):
1182+
"""Test _extract_final_metrics handles multiple different metrics."""
1183+
container = Mock()
1184+
container.logs = "GAN_loss=0.138318;\nAccuracy=0.95;\nLoss=1.234;"
1185+
container.metric_definitions = [
1186+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"},
1187+
{"Name": "accuracy", "Regex": r"Accuracy=([\d\.]+);"},
1188+
{"Name": "loss", "Regex": r"Loss=([\d\.]+);"}
1189+
]
1190+
job = _LocalTrainingJob(container)
1191+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
1192+
1193+
result = job._extract_final_metrics()
1194+
1195+
assert len(result) == 3
1196+
metric_names = [m["MetricName"] for m in result]
1197+
assert "ganloss" in metric_names
1198+
assert "accuracy" in metric_names
1199+
assert "loss" in metric_names
1200+
1201+
def test_extract_final_metrics_no_matches(self):
1202+
"""Test _extract_final_metrics returns empty list when regex doesn't match."""
1203+
container = Mock()
1204+
container.logs = "Training started\nTraining complete"
1205+
container.metric_definitions = [
1206+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
1207+
]
1208+
job = _LocalTrainingJob(container)
1209+
1210+
result = job._extract_final_metrics()
1211+
1212+
assert result == []
1213+
1214+
def test_extract_final_metrics_invalid_metric_definition(self):
1215+
"""Test _extract_final_metrics skips invalid metric definitions."""
1216+
container = Mock()
1217+
container.logs = "GAN_loss=0.138318;"
1218+
container.metric_definitions = [
1219+
{"Name": "ganloss"}, # Missing Regex
1220+
{"Regex": r"GAN_loss=([\d\.]+);"}, # Missing Name
1221+
{"Name": "valid", "Regex": r"GAN_loss=([\d\.]+);"} # Valid
1222+
]
1223+
job = _LocalTrainingJob(container)
1224+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
1225+
1226+
result = job._extract_final_metrics()
1227+
1228+
assert len(result) == 1
1229+
assert result[0]["MetricName"] == "valid"
1230+
1231+
@patch("sagemaker.local.entities.datetime")
1232+
def test_extract_final_metrics_uses_current_time_when_no_end_time(self, mock_datetime):
1233+
"""Test _extract_final_metrics uses current time when end_time is None."""
1234+
container = Mock()
1235+
container.logs = "GAN_loss=0.138318;"
1236+
container.metric_definitions = [
1237+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
1238+
]
1239+
job = _LocalTrainingJob(container)
1240+
job.end_time = None
1241+
1242+
mock_now = datetime(2023, 1, 1, 12, 0, 0)
1243+
mock_datetime.now.return_value = mock_now
1244+
1245+
result = job._extract_final_metrics()
1246+
1247+
assert len(result) == 1
1248+
assert result[0]["Timestamp"] == mock_now
1249+
1250+
@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
1251+
def test_integration_describe_training_job_with_metrics(self, mock_train):
1252+
"""Integration test: describe_training_job includes FinalMetricDataList."""
1253+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
1254+
1255+
algo_spec = {"TrainingImage": "my-image:1.0"}
1256+
input_data_config = [{
1257+
"ChannelName": "training",
1258+
"DataSource": {
1259+
"S3DataSource": {
1260+
"S3DataDistributionType": "FullyReplicated",
1261+
"S3Uri": "s3://bucket/data"
1262+
}
1263+
}
1264+
}]
1265+
output_data_config = {}
1266+
resource_config = {"InstanceType": "local", "InstanceCount": 1}
1267+
1268+
# Create training job
1269+
local_sagemaker_client.create_training_job(
1270+
"test-job",
1271+
algo_spec,
1272+
output_data_config,
1273+
resource_config,
1274+
InputDataConfig=input_data_config,
1275+
HyperParameters={}
1276+
)
1277+
1278+
# Mock the container logs and metric definitions
1279+
training_job = local_sagemaker_client._training_jobs["test-job"]
1280+
training_job.container.logs = "GAN_loss=0.138318;"
1281+
training_job.container.metric_definitions = [
1282+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
1283+
]
1284+
1285+
response = local_sagemaker_client.describe_training_job("test-job")
1286+
1287+
assert "FinalMetricDataList" in response
1288+
assert len(response["FinalMetricDataList"]) == 1
1289+
assert response["FinalMetricDataList"][0]["MetricName"] == "ganloss"
1290+
assert response["FinalMetricDataList"][0]["Value"] == 0.138318

0 commit comments

Comments
 (0)