Skip to content

Commit 5d43462

Browse files
committed
tests: add unit test for local training job describe
1 parent b1f65b8 commit 5d43462

File tree

1 file changed

+206
-0
lines changed

1 file changed

+206
-0
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from datetime import datetime
17+
from mock import Mock, patch
18+
19+
from sagemaker.local.entities import _LocalTrainingJob
20+
21+
22+
class TestLocalTrainingJobFinalMetrics:
23+
"""Test cases for FinalMetricDataList functionality in _LocalTrainingJob."""
24+
25+
def test_describe_includes_final_metric_data_list(self):
26+
"""Test that describe() includes FinalMetricDataList field."""
27+
container = Mock()
28+
job = _LocalTrainingJob(container)
29+
job.training_job_name = "test-job"
30+
job.state = "Completed"
31+
job.start_time = datetime.now()
32+
job.end_time = datetime.now()
33+
job.model_artifacts = "/path/to/model"
34+
job.output_data_config = {}
35+
job.environment = {}
36+
37+
response = job.describe()
38+
39+
assert "FinalMetricDataList" in response
40+
assert isinstance(response["FinalMetricDataList"], list)
41+
42+
def test_extract_final_metrics_no_logs(self):
43+
"""Test _extract_final_metrics returns empty list when no logs."""
44+
container = Mock()
45+
container.logs = None
46+
job = _LocalTrainingJob(container)
47+
48+
result = job._extract_final_metrics()
49+
50+
assert result == []
51+
52+
def test_extract_final_metrics_no_metric_definitions(self):
53+
"""Test _extract_final_metrics returns empty list when no metric definitions."""
54+
container = Mock()
55+
container.logs = "some logs"
56+
container.metric_definitions = []
57+
job = _LocalTrainingJob(container)
58+
59+
result = job._extract_final_metrics()
60+
61+
assert result == []
62+
63+
def test_extract_final_metrics_with_valid_metrics(self):
64+
"""Test _extract_final_metrics extracts metrics correctly."""
65+
container = Mock()
66+
container.logs = "Training started\nGAN_loss=0.138318;\nTraining complete"
67+
container.metric_definitions = [
68+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
69+
]
70+
job = _LocalTrainingJob(container)
71+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
72+
73+
result = job._extract_final_metrics()
74+
75+
assert len(result) == 1
76+
assert result[0]["MetricName"] == "ganloss"
77+
assert result[0]["Value"] == 0.138318
78+
assert result[0]["Timestamp"] == job.end_time
79+
80+
def test_extract_final_metrics_multiple_matches_uses_last(self):
81+
"""Test _extract_final_metrics uses the last match for each metric."""
82+
container = Mock()
83+
container.logs = "GAN_loss=0.5;\nGAN_loss=0.3;\nGAN_loss=0.138318;"
84+
container.metric_definitions = [
85+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
86+
]
87+
job = _LocalTrainingJob(container)
88+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
89+
90+
result = job._extract_final_metrics()
91+
92+
assert len(result) == 1
93+
assert result[0]["Value"] == 0.138318
94+
95+
def test_extract_final_metrics_multiple_metrics(self):
96+
"""Test _extract_final_metrics handles multiple different metrics."""
97+
container = Mock()
98+
container.logs = "GAN_loss=0.138318;\nAccuracy=0.95;\nLoss=1.234;"
99+
container.metric_definitions = [
100+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"},
101+
{"Name": "accuracy", "Regex": r"Accuracy=([\d\.]+);"},
102+
{"Name": "loss", "Regex": r"Loss=([\d\.]+);"}
103+
]
104+
job = _LocalTrainingJob(container)
105+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
106+
107+
result = job._extract_final_metrics()
108+
109+
assert len(result) == 3
110+
metric_names = [m["MetricName"] for m in result]
111+
assert "ganloss" in metric_names
112+
assert "accuracy" in metric_names
113+
assert "loss" in metric_names
114+
115+
def test_extract_final_metrics_no_matches(self):
116+
"""Test _extract_final_metrics returns empty list when regex doesn't match."""
117+
container = Mock()
118+
container.logs = "Training started\nTraining complete"
119+
container.metric_definitions = [
120+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
121+
]
122+
job = _LocalTrainingJob(container)
123+
124+
result = job._extract_final_metrics()
125+
126+
assert result == []
127+
128+
def test_extract_final_metrics_invalid_metric_definition(self):
129+
"""Test _extract_final_metrics skips invalid metric definitions."""
130+
container = Mock()
131+
container.logs = "GAN_loss=0.138318;"
132+
container.metric_definitions = [
133+
{"Name": "ganloss"}, # Missing Regex
134+
{"Regex": r"GAN_loss=([\d\.]+);"}, # Missing Name
135+
{"Name": "valid", "Regex": r"GAN_loss=([\d\.]+);"} # Valid
136+
]
137+
job = _LocalTrainingJob(container)
138+
job.end_time = datetime(2023, 1, 1, 12, 0, 0)
139+
140+
result = job._extract_final_metrics()
141+
142+
assert len(result) == 1
143+
assert result[0]["MetricName"] == "valid"
144+
145+
@patch("sagemaker.local.entities.datetime")
146+
def test_extract_final_metrics_uses_current_time_when_no_end_time(self, mock_datetime):
147+
"""Test _extract_final_metrics uses current time when end_time is None."""
148+
container = Mock()
149+
container.logs = "GAN_loss=0.138318;"
150+
container.metric_definitions = [
151+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
152+
]
153+
job = _LocalTrainingJob(container)
154+
job.end_time = None
155+
156+
mock_now = datetime(2023, 1, 1, 12, 0, 0)
157+
mock_datetime.now.return_value = mock_now
158+
159+
result = job._extract_final_metrics()
160+
161+
assert len(result) == 1
162+
assert result[0]["Timestamp"] == mock_now
163+
164+
@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
165+
def test_integration_describe_training_job_with_metrics(self, mock_train):
166+
"""Integration test: describe_training_job includes FinalMetricDataList."""
167+
from sagemaker.local.local_session import LocalSagemakerClient
168+
169+
local_sagemaker_client = LocalSagemakerClient()
170+
171+
algo_spec = {"TrainingImage": "my-image:1.0"}
172+
input_data_config = [{
173+
"ChannelName": "training",
174+
"DataSource": {
175+
"S3DataSource": {
176+
"S3DataDistributionType": "FullyReplicated",
177+
"S3Uri": "s3://bucket/data"
178+
}
179+
}
180+
}]
181+
output_data_config = {}
182+
resource_config = {"InstanceType": "local", "InstanceCount": 1}
183+
184+
# Create training job
185+
local_sagemaker_client.create_training_job(
186+
"test-job",
187+
algo_spec,
188+
output_data_config,
189+
resource_config,
190+
InputDataConfig=input_data_config,
191+
HyperParameters={}
192+
)
193+
194+
# Mock the container logs and metric definitions
195+
training_job = local_sagemaker_client._training_jobs["test-job"]
196+
training_job.container.logs = "GAN_loss=0.138318;"
197+
training_job.container.metric_definitions = [
198+
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
199+
]
200+
201+
response = local_sagemaker_client.describe_training_job("test-job")
202+
203+
assert "FinalMetricDataList" in response
204+
assert len(response["FinalMetricDataList"]) == 1
205+
assert response["FinalMetricDataList"][0]["MetricName"] == "ganloss"
206+
assert response["FinalMetricDataList"][0]["Value"] == 0.138318

0 commit comments

Comments
 (0)