1313"""MPI Utils Unit Tests."""
1414from __future__ import absolute_import
1515
16- import os
16+ import subprocess
1717from unittest .mock import Mock , patch
1818
1919import paramiko
2020import pytest
2121
22- from sagemaker .modules .train .container_drivers .mpi_utils import (
23- CustomHostKeyPolicy ,
24- _can_connect ,
25- bootstrap_master_node ,
26- bootstrap_worker_node ,
27- get_mpirun_command ,
28- write_status_file_to_workers ,
29- )
22+ # Mock the utils module before importing mpi_utils
23+ mock_utils = Mock ()
24+ mock_utils .logger = Mock ()
25+ mock_utils .SM_EFA_NCCL_INSTANCES = []
26+ mock_utils .SM_EFA_RDMA_INSTANCES = []
27+ mock_utils .get_python_executable = Mock (return_value = "/usr/bin/python" )
28+
29+ with patch .dict ("sys.modules" , {"utils" : mock_utils }):
30+ from sagemaker .modules .train .container_drivers .mpi_utils import (
31+ CustomHostKeyPolicy ,
32+ _can_connect ,
33+ write_status_file_to_workers ,
34+ )
3035
3136TEST_HOST = "algo-1"
3237TEST_WORKER = "algo-2"
3338TEST_STATUS_FILE = "/tmp/test-status"
3439
3540
3641def test_custom_host_key_policy_valid_hostname ():
37- """Test CustomHostKeyPolicy with valid algo- hostname ."""
42+ """Test CustomHostKeyPolicy accepts algo- prefixed hostnames ."""
3843 policy = CustomHostKeyPolicy ()
3944 mock_client = Mock ()
4045 mock_key = Mock ()
@@ -47,7 +52,7 @@ def test_custom_host_key_policy_valid_hostname():
4752
4853
4954def test_custom_host_key_policy_invalid_hostname ():
50- """Test CustomHostKeyPolicy with invalid hostname ."""
55+ """Test CustomHostKeyPolicy rejects non-algo prefixed hostnames ."""
5156 policy = CustomHostKeyPolicy ()
5257 mock_client = Mock ()
5358 mock_key = Mock ()
@@ -60,112 +65,51 @@ def test_custom_host_key_policy_invalid_hostname():
6065
6166
6267@patch ("paramiko.SSHClient" )
63- def test_can_connect_success (mock_ssh_client ):
68+ @patch ("sagemaker.modules.train.container_drivers.mpi_utils.logger" )
69+ def test_can_connect_success (mock_logger , mock_ssh_client ):
6470 """Test successful SSH connection."""
6571 mock_client = Mock ()
66- mock_ssh_client .return_value = mock_client
72+ mock_ssh_client .return_value .__enter__ .return_value = mock_client
73+ mock_client .connect .return_value = None # Successful connection
74+
75+ result = _can_connect (TEST_HOST )
6776
68- assert _can_connect (TEST_HOST ) is True
77+ assert result is True
78+ mock_client .load_system_host_keys .assert_called_once ()
79+ mock_client .set_missing_host_key_policy .assert_called_once ()
6980 mock_client .connect .assert_called_once_with (TEST_HOST , port = 22 )
81+ mock_logger .info .assert_called_with ("Can connect to host %s" , TEST_HOST )
7082
7183
7284@patch ("paramiko.SSHClient" )
73- def test_can_connect_failure (mock_ssh_client ):
85+ @patch ("sagemaker.modules.train.container_drivers.mpi_utils.logger" )
86+ def test_can_connect_failure (mock_logger , mock_ssh_client ):
7487 """Test SSH connection failure."""
7588 mock_client = Mock ()
76- mock_ssh_client .return_value = mock_client
77- mock_client .connect .side_effect = Exception ("Connection failed" )
78-
79- assert _can_connect (TEST_HOST ) is False
89+ mock_ssh_client .return_value .__enter__ .return_value = mock_client
90+ mock_client .connect .side_effect = paramiko .SSHException ("Connection failed" )
8091
92+ result = _can_connect (TEST_HOST )
8193
82- @patch ("subprocess.run" )
83- def test_write_status_file_to_workers_success (mock_run ):
84- """Test successful status file writing to workers."""
85- mock_run .return_value = Mock (returncode = 0 )
86-
87- write_status_file_to_workers ([TEST_WORKER ], TEST_STATUS_FILE )
88-
89- mock_run .assert_called_once ()
90- args = mock_run .call_args [0 ][0 ]
91- assert args == ["ssh" , TEST_WORKER , "touch" , TEST_STATUS_FILE ]
94+ assert result is False
95+ mock_client .load_system_host_keys .assert_called_once ()
96+ mock_client .set_missing_host_key_policy .assert_called_once ()
97+ mock_client .connect .assert_called_once_with (TEST_HOST , port = 22 )
98+ mock_logger .info .assert_called_with ("Cannot connect to host %s" , TEST_HOST )
9299
93100
94101@patch ("subprocess.run" )
95- def test_write_status_file_to_workers_failure (mock_run ):
102+ @patch ("sagemaker.modules.train.container_drivers.mpi_utils.logger" )
103+ def test_write_status_file_to_workers_failure (mock_logger , mock_run ):
96104 """Test failed status file writing to workers with retry timeout."""
97- mock_run .side_effect = Exception ( "SSH failed " )
105+ mock_run .side_effect = subprocess . CalledProcessError ( 1 , "ssh " )
98106
99107 with pytest .raises (TimeoutError ) as exc_info :
100108 write_status_file_to_workers ([TEST_WORKER ], TEST_STATUS_FILE )
101109
102110 assert f"Timed out waiting for { TEST_WORKER } " in str (exc_info .value )
103-
104-
105- def test_get_mpirun_command_basic ():
106- """Test basic MPI command generation."""
107- with patch .dict (
108- os .environ ,
109- {"SM_NETWORK_INTERFACE_NAME" : "eth0" , "SM_CURRENT_INSTANCE_TYPE" : "ml.p3.16xlarge" },
110- ):
111- command = get_mpirun_command (
112- host_count = 2 ,
113- host_list = [TEST_HOST , TEST_WORKER ],
114- num_processes = 2 ,
115- additional_options = [],
116- entry_script_path = "train.py" ,
117- )
118-
119- assert command [0 ] == "mpirun"
120- assert "--host" in command
121- assert f"{ TEST_HOST } ,{ TEST_WORKER } " in command
122- assert "-np" in command
123- assert "2" in command
124-
125-
126- def test_get_mpirun_command_efa ():
127- """Test MPI command generation with EFA instance."""
128- with patch .dict (
129- os .environ ,
130- {"SM_NETWORK_INTERFACE_NAME" : "eth0" , "SM_CURRENT_INSTANCE_TYPE" : "ml.p4d.24xlarge" },
131- ):
132- command = get_mpirun_command (
133- host_count = 2 ,
134- host_list = [TEST_HOST , TEST_WORKER ],
135- num_processes = 2 ,
136- additional_options = [],
137- entry_script_path = "train.py" ,
138- )
139-
140- command_str = " " .join (command )
141- assert "FI_PROVIDER=efa" in command_str
142- assert "NCCL_PROTO=simple" in command_str
143-
144-
145- @patch ("sagemaker.modules.train.container_drivers.mpi_utils._can_connect" )
146- @patch ("sagemaker.modules.train.container_drivers.mpi_utils._write_file_to_host" )
147- def test_bootstrap_worker_node (mock_write , mock_connect ):
148- """Test worker node bootstrap process."""
149- mock_connect .return_value = True
150- mock_write .return_value = True
151-
152- with patch .dict (os .environ , {"SM_CURRENT_HOST" : TEST_WORKER }):
153- with pytest .raises (TimeoutError ):
154- bootstrap_worker_node (TEST_HOST , timeout = 1 )
155-
156- mock_connect .assert_called_with (TEST_HOST )
157- mock_write .assert_called_with (TEST_HOST , f"/tmp/ready.{ TEST_WORKER } " )
158-
159-
160- @patch ("sagemaker.modules.train.container_drivers.mpi_utils._can_connect" )
161- def test_bootstrap_master_node (mock_connect ):
162- """Test master node bootstrap process."""
163- mock_connect .return_value = True
164-
165- with pytest .raises (TimeoutError ):
166- bootstrap_master_node ([TEST_WORKER ], timeout = 1 )
167-
168- mock_connect .assert_called_with (TEST_WORKER )
111+ assert mock_run .call_count > 1 # Verifies that retries occurred
112+ mock_logger .info .assert_any_call (f"Cannot connect to { TEST_WORKER } " )
169113
170114
171115if __name__ == "__main__" :
0 commit comments