Skip to content

Commit 5b19a68

Browse files
committed
Readd the flaky tests
1 parent: fbf0b9b · commit: 5b19a68

File tree

1 file changed

+40
-38
lines changed

1 file changed

+40
-38
lines changed

tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py

Lines changed: 40 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
"""MPI Utils Unit Tests."""
1414
from __future__ import absolute_import
1515

16-
# import subprocess
16+
import subprocess
1717
from unittest.mock import Mock, patch
1818

1919
import paramiko
@@ -29,7 +29,9 @@
2929
with patch.dict("sys.modules", {"utils": mock_utils}):
3030
from sagemaker.modules.train.container_drivers.mpi_utils import (
3131
CustomHostKeyPolicy,
32-
) # _can_connect,; write_status_file_to_workers,
32+
_can_connect,
33+
write_status_file_to_workers,
34+
)
3335

3436
TEST_HOST = "algo-1"
3537
TEST_WORKER = "algo-2"
@@ -62,52 +64,52 @@ def test_custom_host_key_policy_invalid_hostname():
6264
mock_client.get_host_keys.assert_not_called()
6365

6466

65-
# @patch("paramiko.SSHClient")
66-
# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
67-
# def test_can_connect_success(mock_logger, mock_ssh_client):
68-
# """Test successful SSH connection."""
69-
# mock_client = Mock()
70-
# mock_ssh_client.return_value.__enter__.return_value = mock_client
71-
# mock_client.connect.return_value = None # Successful connection
67+
@patch("paramiko.SSHClient")
68+
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
69+
def test_can_connect_success(mock_logger, mock_ssh_client):
70+
"""Test successful SSH connection."""
71+
mock_client = Mock()
72+
mock_ssh_client.return_value.__enter__.return_value = mock_client
73+
mock_client.connect.return_value = None # Successful connection
7274

73-
# result = _can_connect(TEST_HOST)
75+
result = _can_connect(TEST_HOST)
7476

75-
# assert result is True
76-
# mock_client.load_system_host_keys.assert_called_once()
77-
# mock_client.set_missing_host_key_policy.assert_called_once()
78-
# mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
79-
# mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST)
77+
assert result is True
78+
mock_client.load_system_host_keys.assert_called_once()
79+
mock_client.set_missing_host_key_policy.assert_called_once()
80+
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
81+
mock_logger.info.assert_called_with("Can connect to host %s", TEST_HOST)
8082

8183

82-
# @patch("paramiko.SSHClient")
83-
# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
84-
# def test_can_connect_failure(mock_logger, mock_ssh_client):
85-
# """Test SSH connection failure."""
86-
# mock_client = Mock()
87-
# mock_ssh_client.return_value.__enter__.return_value = mock_client
88-
# mock_client.connect.side_effect = paramiko.SSHException("Connection failed")
84+
@patch("paramiko.SSHClient")
85+
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
86+
def test_can_connect_failure(mock_logger, mock_ssh_client):
87+
"""Test SSH connection failure."""
88+
mock_client = Mock()
89+
mock_ssh_client.return_value.__enter__.return_value = mock_client
90+
mock_client.connect.side_effect = paramiko.SSHException("Connection failed")
8991

90-
# result = _can_connect(TEST_HOST)
92+
result = _can_connect(TEST_HOST)
9193

92-
# assert result is False
93-
# mock_client.load_system_host_keys.assert_called_once()
94-
# mock_client.set_missing_host_key_policy.assert_called_once()
95-
# mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
96-
# mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST)
94+
assert result is False
95+
mock_client.load_system_host_keys.assert_called_once()
96+
mock_client.set_missing_host_key_policy.assert_called_once()
97+
mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
98+
mock_logger.info.assert_called_with("Cannot connect to host %s", TEST_HOST)
9799

98100

99-
# @patch("subprocess.run")
100-
# @patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
101-
# def test_write_status_file_to_workers_failure(mock_logger, mock_run):
102-
# """Test failed status file writing to workers with retry timeout."""
103-
# mock_run.side_effect = subprocess.CalledProcessError(1, "ssh")
101+
@patch("subprocess.run")
102+
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
103+
def test_write_status_file_to_workers_failure(mock_logger, mock_run):
104+
"""Test failed status file writing to workers with retry timeout."""
105+
mock_run.side_effect = subprocess.CalledProcessError(1, "ssh")
104106

105-
# with pytest.raises(TimeoutError) as exc_info:
106-
# write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)
107+
with pytest.raises(TimeoutError) as exc_info:
108+
write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE)
107109

108-
# assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value)
109-
# assert mock_run.call_count > 1 # Verifies that retries occurred
110-
# mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}")
110+
assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value)
111+
assert mock_run.call_count > 1 # Verifies that retries occurred
112+
mock_logger.info.assert_any_call(f"Cannot connect to {TEST_WORKER}")
111113

112114

113115
if __name__ == "__main__":

0 commit comments

Comments
 (0)