|
13 | 13 | """MPI Utils Unit Tests.""" |
14 | 14 | from __future__ import absolute_import |
15 | 15 |
|
16 | | -# import subprocess |
| 16 | +import subprocess |
17 | 17 | from unittest.mock import Mock, patch |
18 | 18 |
|
19 | 19 | import paramiko |
|
29 | 29 | with patch.dict("sys.modules", {"utils": mock_utils}): |
30 | 30 | from sagemaker.modules.train.container_drivers.mpi_utils import ( |
31 | 31 | CustomHostKeyPolicy, |
32 | | - ) # _can_connect,; write_status_file_to_workers, |
| 32 | + _can_connect, |
| 33 | + write_status_file_to_workers, |
| 34 | + ) |
33 | 35 |
|
34 | 36 | TEST_HOST = "algo-1" |
35 | 37 | TEST_WORKER = "algo-2" |
@@ -62,52 +64,52 @@ def test_custom_host_key_policy_invalid_hostname(): |
62 | 64 | mock_client.get_host_keys.assert_not_called() |
63 | 65 |
|
64 | 66 |
|
65 | | -# @patch("paramiko.SSHClient") |
66 | | -# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
67 | | -# def test_can_connect_success(mock_logger, mock_ssh_client): |
68 | | -# """Test successful SSH connection.""" |
69 | | -# mock_client = Mock() |
70 | | -# mock_ssh_client.return_value.__enter__.return_value = mock_client |
71 | | -# mock_client.connect.return_value = None # Successful connection |
| 67 | +@patch("paramiko.SSHClient") |
| 68 | +@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
| 69 | +def test_can_connect_success(mock_logger, mock_ssh_client): |
| 70 | + """Test successful SSH connection.""" |
| 71 | + mock_client = Mock() |
| 72 | + mock_ssh_client.return_value.__enter__.return_value = mock_client |
| 73 | + mock_client.connect.return_value = None # Successful connection |
72 | 74 |
|
73 | | -# result = _can_connect(TEST_HOST) |
| 75 | + result = _can_connect(TEST_HOST) |
74 | 76 |
|
75 | | -# assert result is True |
76 | | -# mock_client.load_system_host_keys.assert_called_once() |
77 | | -# mock_client.set_missing_host_key_policy.assert_called_once() |
78 | | -# mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
79 | | -# mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST) |
| 77 | + assert result is True |
| 78 | + mock_client.load_system_host_keys.assert_called_once() |
| 79 | + mock_client.set_missing_host_key_policy.assert_called_once() |
| 80 | + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
| 81 | + mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST) |
80 | 82 |
|
81 | 83 |
|
82 | | -# @patch("paramiko.SSHClient") |
83 | | -# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
84 | | -# def test_can_connect_failure(mock_logger, mock_ssh_client): |
85 | | -# """Test SSH connection failure.""" |
86 | | -# mock_client = Mock() |
87 | | -# mock_ssh_client.return_value.__enter__.return_value = mock_client |
88 | | -# mock_client.connect.side_effect = paramiko.SSHException("Connection failed") |
| 84 | +@patch("paramiko.SSHClient") |
| 85 | +@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
| 86 | +def test_can_connect_failure(mock_logger, mock_ssh_client): |
| 87 | + """Test SSH connection failure.""" |
| 88 | + mock_client = Mock() |
| 89 | + mock_ssh_client.return_value.__enter__.return_value = mock_client |
| 90 | + mock_client.connect.side_effect = paramiko.SSHException("Connection failed") |
89 | 91 |
|
90 | | -# result = _can_connect(TEST_HOST) |
| 92 | + result = _can_connect(TEST_HOST) |
91 | 93 |
|
92 | | -# assert result is False |
93 | | -# mock_client.load_system_host_keys.assert_called_once() |
94 | | -# mock_client.set_missing_host_key_policy.assert_called_once() |
95 | | -# mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
96 | | -# mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST) |
| 94 | + assert result is False |
| 95 | + mock_client.load_system_host_keys.assert_called_once() |
| 96 | + mock_client.set_missing_host_key_policy.assert_called_once() |
| 97 | + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
| 98 | + mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST) |
97 | 99 |
|
98 | 100 |
|
99 | | -# @patch("subprocess.run") |
100 | | -# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
101 | | -# def test_write_status_file_to_workers_failure(mock_logger, mock_run): |
102 | | -# """Test failed status file writing to workers with retry timeout.""" |
103 | | -# mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") |
| 101 | +@patch("subprocess.run") |
| 102 | +@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
| 103 | +def test_write_status_file_to_workers_failure(mock_logger, mock_run): |
| 104 | + """Test failed status file writing to workers with retry timeout.""" |
| 105 | + mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") |
104 | 106 |
|
105 | | -# with pytest.raises(TimeoutError) as exc_info: |
106 | | -# write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) |
| 107 | + with pytest.raises(TimeoutError) as exc_info: |
| 108 | + write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) |
107 | 109 |
|
108 | | -# assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) |
109 | | -# assert mock_run.call_count > 1 # Verifies that retries occurred |
110 | | -# mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}") |
| 110 | + assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) |
| 111 | + assert mock_run.call_count > 1 # Verifies that retries occurred |
| 112 | + mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}") |
111 | 113 |
|
112 | 114 |
|
113 | 115 | if __name__ == "__main__": |
|
0 commit comments