|
18 | 18 | import json |
19 | 19 | import os |
20 | 20 | import pytest |
| 21 | +from pydantic import ValidationError |
21 | 22 | from unittest.mock import patch, MagicMock, ANY |
22 | 23 |
|
23 | 24 | from sagemaker import image_uris |
@@ -442,7 +443,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ |
442 | 443 | { |
443 | 444 | "source_code": DEFAULT_SOURCE_CODE, |
444 | 445 | "distributed": MPI( |
445 | | - custom_mpi_options=["-x", "VAR1", "-x", "VAR2"], |
| 446 | + mpi_additional_options=["-x", "VAR1", "-x", "VAR2"], |
446 | 447 | ), |
447 | 448 | "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( |
448 | 449 | driver_name="MPI", driver_script="mpi_driver.py" |
@@ -1059,3 +1060,36 @@ def mock_upload_data(path, bucket, key_prefix): |
1059 | 1060 | hyper_parameters=hyperparameters, |
1060 | 1061 | environment=environment, |
1061 | 1062 | ) |
| 1063 | + |
| 1064 | + |
| 1065 | +def test_safe_configs(): |
| 1066 | + # Test extra fails |
| 1067 | + with pytest.raises(ValueError): |
| 1068 | + SourceCode(entry_point="train.py") |
| 1069 | + # Test invalid type fails |
| 1070 | + with pytest.raises(ValueError): |
| 1071 | + SourceCode(entry_script=1) |
| 1072 | + |
| 1073 | + |
| 1074 | +@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") |
| 1075 | +def test_destructor_cleanup(mock_tmp_dir, modules_session): |
| 1076 | + |
| 1077 | + with pytest.raises(ValidationError): |
| 1078 | + model_trainer = ModelTrainer( |
| 1079 | + training_image=DEFAULT_IMAGE, |
| 1080 | + role=DEFAULT_ROLE, |
| 1081 | + sagemaker_session=modules_session, |
| 1082 | + compute="test", |
| 1083 | + ) |
| 1084 | + mock_tmp_dir.cleanup.assert_not_called() |
| 1085 | + |
| 1086 | + model_trainer = ModelTrainer( |
| 1087 | + training_image=DEFAULT_IMAGE, |
| 1088 | + role=DEFAULT_ROLE, |
| 1089 | + sagemaker_session=modules_session, |
| 1090 | + compute=DEFAULT_COMPUTE_CONFIG, |
| 1091 | + ) |
| 1092 | + model_trainer._temp_recipe_train_dir = mock_tmp_dir |
| 1093 | + mock_tmp_dir.assert_not_called() |
| 1094 | + del model_trainer |
| 1095 | + mock_tmp_dir.cleanup.assert_called_once() |
0 commit comments