Skip to content

Commit 155a6b8

Browse files
committed
change: add unit tests and remove unreachable condition
1 parent e42af65 commit 155a6b8

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,6 @@ def model_post_init(self, __context: Any):
483483
except json.JSONDecodeError:
484484
try:
485485
self.hyperparameters = yaml.safe_load(contents)
486-
if not isinstance(self.hyperparameters, dict):
487-
raise ValueError("YAML content is not a valid mapping.")
488486
logger.debug("Hyperparameters loaded as YAML")
489487
except (yaml.YAMLError, ValueError):
490488
raise ValueError(

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
import tempfile
1818
import json
1919
import os
20+
import yaml
2021
import pytest
2122
from pydantic import ValidationError
22-
from unittest.mock import patch, MagicMock, ANY
23+
from unittest.mock import patch, MagicMock, ANY, mock_open
2324

2425
from sagemaker import image_uris
2526
from sagemaker_core.main.resources import TrainingJob
@@ -1093,3 +1094,68 @@ def test_destructor_cleanup(mock_tmp_dir, modules_session):
10931094
mock_tmp_dir.assert_not_called()
10941095
del model_trainer
10951096
mock_tmp_dir.cleanup.assert_called_once()
1097+
1098+
1099+
@patch("os.path.exists")
1100+
def test_hyperparameters_valid_json(mock_exists, modules_session):
1101+
mock_exists.return_value = True
1102+
expected_hyperparameters = {"param1": "value1", "param2": 2}
1103+
mock_file_open = mock_open(read_data=json.dumps(expected_hyperparameters))
1104+
1105+
with patch("builtins.open", mock_file_open):
1106+
model_trainer = ModelTrainer(
1107+
training_image=DEFAULT_IMAGE,
1108+
role=DEFAULT_ROLE,
1109+
sagemaker_session=modules_session,
1110+
compute=DEFAULT_COMPUTE_CONFIG,
1111+
hyperparameters="hyperparameters.json",
1112+
)
1113+
assert model_trainer.hyperparameters == expected_hyperparameters
1114+
mock_file_open.assert_called_once_with("hyperparameters.json", "r")
1115+
mock_exists.assert_called_once_with("hyperparameters.json")
1116+
1117+
1118+
@patch("os.path.exists")
1119+
def test_hyperparameters_valid_yaml(mock_exists, modules_session):
1120+
mock_exists.return_value = True
1121+
expected_hyperparameters = {"param1": "value1", "param2": 2}
1122+
mock_file_open = mock_open(read_data=yaml.dump(expected_hyperparameters))
1123+
1124+
with patch("builtins.open", mock_file_open):
1125+
model_trainer = ModelTrainer(
1126+
training_image=DEFAULT_IMAGE,
1127+
role=DEFAULT_ROLE,
1128+
sagemaker_session=modules_session,
1129+
compute=DEFAULT_COMPUTE_CONFIG,
1130+
hyperparameters="hyperparameters.yaml",
1131+
)
1132+
assert model_trainer.hyperparameters == expected_hyperparameters
1133+
mock_file_open.assert_called_once_with("hyperparameters.yaml", "r")
1134+
mock_exists.assert_called_once_with("hyperparameters.yaml")
1135+
1136+
1137+
def test_hyperparameters_not_exist(modules_session):
1138+
with pytest.raises(ValueError):
1139+
ModelTrainer(
1140+
training_image=DEFAULT_IMAGE,
1141+
role=DEFAULT_ROLE,
1142+
sagemaker_session=modules_session,
1143+
compute=DEFAULT_COMPUTE_CONFIG,
1144+
hyperparameters="nonexistent.json",
1145+
)
1146+
1147+
1148+
@patch("os.path.exists")
1149+
def test_hyperparameters_invalid(mock_exists, modules_session):
1150+
mock_exists.return_value = True
1151+
# Must be valid YAML or JSON
1152+
mock_file_open = mock_open(read_data="invalid")
1153+
with patch("builtins.open", mock_file_open):
1154+
with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."):
1155+
ModelTrainer(
1156+
training_image=DEFAULT_IMAGE,
1157+
role=DEFAULT_ROLE,
1158+
sagemaker_session=modules_session,
1159+
compute=DEFAULT_COMPUTE_CONFIG,
1160+
hyperparameters="hyperparameters.yaml",
1161+
)

0 commit comments

Comments
 (0)