|
13 | 13 | """MPI Utils Unit Tests."""
|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
16 |
| -# import subprocess |
| 16 | +import subprocess |
17 | 17 | from unittest.mock import Mock, patch
|
18 | 18 |
|
19 | 19 | import paramiko
|
|
29 | 29 | with patch.dict("sys.modules", {"utils": mock_utils}):
|
30 | 30 | from sagemaker.modules.train.container_drivers.mpi_utils import (
|
31 | 31 | CustomHostKeyPolicy,
|
32 |
| - ) # _can_connect,; write_status_file_to_workers, |
| 32 | + _can_connect, |
| 33 | + write_status_file_to_workers, |
| 34 | + ) |
33 | 35 |
|
34 | 36 | TEST_HOST = "algo-1"
|
35 | 37 | TEST_WORKER = "algo-2"
|
@@ -62,52 +64,52 @@ def test_custom_host_key_policy_invalid_hostname():
|
62 | 64 | mock_client.get_host_keys.assert_not_called()
|
63 | 65 |
|
64 | 66 |
|
65 |
| -# @patch("paramiko.SSHClient") |
66 |
| -# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
67 |
| -# def test_can_connect_success(mock_logger, mock_ssh_client): |
68 |
| -# """Test successful SSH connection.""" |
69 |
| -# mock_client = Mock() |
70 |
| -# mock_ssh_client.return_value.__enter__.return_value = mock_client |
71 |
| -# mock_client.connect.return_value = None # Successful connection |
| 67 | +@patch("paramiko.SSHClient") |
| 68 | +@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
| 69 | +def test_can_connect_success(mock_logger, mock_ssh_client): |
| 70 | + """Test successful SSH connection.""" |
| 71 | + mock_client = Mock() |
| 72 | + mock_ssh_client.return_value.__enter__.return_value = mock_client |
| 73 | + mock_client.connect.return_value = None # Successful connection |
72 | 74 |
|
73 |
| -# result = _can_connect(TEST_HOST) |
| 75 | + result = _can_connect(TEST_HOST) |
74 | 76 |
|
75 |
| -# assert result is True |
76 |
| -# mock_client.load_system_host_keys.assert_called_once() |
77 |
| -# mock_client.set_missing_host_key_policy.assert_called_once() |
78 |
| -# mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
79 |
| -# mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST) |
| 77 | + assert result is True |
| 78 | + mock_client.load_system_host_keys.assert_called_once() |
| 79 | + mock_client.set_missing_host_key_policy.assert_called_once() |
| 80 | + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
| 81 | + mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST) |
80 | 82 |
|
81 | 83 |
|
82 |
| -# @patch("paramiko.SSHClient") |
83 |
| -# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
84 |
| -# def test_can_connect_failure(mock_logger, mock_ssh_client): |
85 |
| -# """Test SSH connection failure.""" |
86 |
| -# mock_client = Mock() |
87 |
| -# mock_ssh_client.return_value.__enter__.return_value = mock_client |
88 |
| -# mock_client.connect.side_effect = paramiko.SSHException("Connection failed") |
| 84 | +@patch("paramiko.SSHClient") |
| 85 | +@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
| 86 | +def test_can_connect_failure(mock_logger, mock_ssh_client): |
| 87 | + """Test SSH connection failure.""" |
| 88 | + mock_client = Mock() |
| 89 | + mock_ssh_client.return_value.__enter__.return_value = mock_client |
| 90 | + mock_client.connect.side_effect = paramiko.SSHException("Connection failed") |
89 | 91 |
|
90 |
| -# result = _can_connect(TEST_HOST) |
| 92 | + result = _can_connect(TEST_HOST) |
91 | 93 |
|
92 |
| -# assert result is False |
93 |
| -# mock_client.load_system_host_keys.assert_called_once() |
94 |
| -# mock_client.set_missing_host_key_policy.assert_called_once() |
95 |
| -# mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
96 |
| -# mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST) |
| 94 | + assert result is False |
| 95 | + mock_client.load_system_host_keys.assert_called_once() |
| 96 | + mock_client.set_missing_host_key_policy.assert_called_once() |
| 97 | + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) |
| 98 | + mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST) |
97 | 99 |
|
98 | 100 |
|
99 |
| -# @patch("subprocess.run") |
100 |
| -# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
101 |
| -# def test_write_status_file_to_workers_failure(mock_logger, mock_run): |
102 |
| -# """Test failed status file writing to workers with retry timeout.""" |
103 |
| -# mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") |
| 101 | +@patch("subprocess.run") |
| 102 | +@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") |
| 103 | +def test_write_status_file_to_workers_failure(mock_logger, mock_run): |
| 104 | + """Test failed status file writing to workers with retry timeout.""" |
| 105 | + mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") |
104 | 106 |
|
105 |
| -# with pytest.raises(TimeoutError) as exc_info: |
106 |
| -# write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) |
| 107 | + with pytest.raises(TimeoutError) as exc_info: |
| 108 | + write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) |
107 | 109 |
|
108 |
| -# assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) |
109 |
| -# assert mock_run.call_count > 1 # Verifies that retries occurred |
110 |
| -# mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}") |
| 110 | + assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) |
| 111 | + assert mock_run.call_count > 1 # Verifies that retries occurred |
| 112 | + mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}") |
111 | 113 |
|
112 | 114 |
|
113 | 115 | if __name__ == "__main__":
|
|
0 commit comments