|
17 | 17 | import tempfile
|
18 | 18 | import json
|
19 | 19 | import os
|
| 20 | +import yaml |
20 | 21 | import pytest
|
21 | 22 | from pydantic import ValidationError
|
22 |
| -from unittest.mock import patch, MagicMock, ANY |
| 23 | +from unittest.mock import patch, MagicMock, ANY, mock_open |
23 | 24 |
|
24 | 25 | from sagemaker import image_uris
|
25 | 26 | from sagemaker_core.main.resources import TrainingJob
|
@@ -1093,3 +1094,68 @@ def test_destructor_cleanup(mock_tmp_dir, modules_session):
|
1093 | 1094 | mock_tmp_dir.assert_not_called()
|
1094 | 1095 | del model_trainer
|
1095 | 1096 | mock_tmp_dir.cleanup.assert_called_once()
|
| 1097 | + |
| 1098 | + |
| 1099 | +@patch("os.path.exists") |
| 1100 | +def test_hyperparameters_valid_json(mock_exists, modules_session): |
| 1101 | + mock_exists.return_value = True |
| 1102 | + expected_hyperparameters = {"param1": "value1", "param2": 2} |
| 1103 | + mock_file_open = mock_open(read_data=json.dumps(expected_hyperparameters)) |
| 1104 | + |
| 1105 | + with patch("builtins.open", mock_file_open): |
| 1106 | + model_trainer = ModelTrainer( |
| 1107 | + training_image=DEFAULT_IMAGE, |
| 1108 | + role=DEFAULT_ROLE, |
| 1109 | + sagemaker_session=modules_session, |
| 1110 | + compute=DEFAULT_COMPUTE_CONFIG, |
| 1111 | + hyperparameters="hyperparameters.json", |
| 1112 | + ) |
| 1113 | + assert model_trainer.hyperparameters == expected_hyperparameters |
| 1114 | + mock_file_open.assert_called_once_with("hyperparameters.json", "r") |
| 1115 | + mock_exists.assert_called_once_with("hyperparameters.json") |
| 1116 | + |
| 1117 | + |
| 1118 | +@patch("os.path.exists") |
| 1119 | +def test_hyperparameters_valid_yaml(mock_exists, modules_session): |
| 1120 | + mock_exists.return_value = True |
| 1121 | + expected_hyperparameters = {"param1": "value1", "param2": 2} |
| 1122 | + mock_file_open = mock_open(read_data=yaml.dump(expected_hyperparameters)) |
| 1123 | + |
| 1124 | + with patch("builtins.open", mock_file_open): |
| 1125 | + model_trainer = ModelTrainer( |
| 1126 | + training_image=DEFAULT_IMAGE, |
| 1127 | + role=DEFAULT_ROLE, |
| 1128 | + sagemaker_session=modules_session, |
| 1129 | + compute=DEFAULT_COMPUTE_CONFIG, |
| 1130 | + hyperparameters="hyperparameters.yaml", |
| 1131 | + ) |
| 1132 | + assert model_trainer.hyperparameters == expected_hyperparameters |
| 1133 | + mock_file_open.assert_called_once_with("hyperparameters.yaml", "r") |
| 1134 | + mock_exists.assert_called_once_with("hyperparameters.yaml") |
| 1135 | + |
| 1136 | + |
| 1137 | +def test_hyperparameters_not_exist(modules_session): |
| 1138 | + with pytest.raises(ValueError): |
| 1139 | + ModelTrainer( |
| 1140 | + training_image=DEFAULT_IMAGE, |
| 1141 | + role=DEFAULT_ROLE, |
| 1142 | + sagemaker_session=modules_session, |
| 1143 | + compute=DEFAULT_COMPUTE_CONFIG, |
| 1144 | + hyperparameters="nonexistent.json", |
| 1145 | + ) |
| 1146 | + |
| 1147 | + |
| 1148 | +@patch("os.path.exists") |
| 1149 | +def test_hyperparameters_invalid(mock_exists, modules_session): |
| 1150 | + mock_exists.return_value = True |
| 1151 | + # Must be valid YAML or JSON |
| 1152 | + mock_file_open = mock_open(read_data="invalid") |
| 1153 | + with patch("builtins.open", mock_file_open): |
| 1154 | + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): |
| 1155 | + ModelTrainer( |
| 1156 | + training_image=DEFAULT_IMAGE, |
| 1157 | + role=DEFAULT_ROLE, |
| 1158 | + sagemaker_session=modules_session, |
| 1159 | + compute=DEFAULT_COMPUTE_CONFIG, |
| 1160 | + hyperparameters="hyperparameters.yaml", |
| 1161 | + ) |
0 commit comments