@@ -77,23 +77,35 @@ def test_can_connect_failure(mock_ssh_client):
7777
7878def test_get_mpirun_command ():
7979 """Test MPI command generation."""
80- os .environ ["SM_NETWORK_INTERFACE_NAME" ] = "eth0"
81- os .environ ["SM_CURRENT_INSTANCE_TYPE" ] = "ml.p4d.24xlarge"
82-
83- command = get_mpirun_command (
84- host_count = 2 ,
85- host_list = ["algo-1" , "algo-2" ],
86- num_processes = 2 ,
87- additional_options = [],
88- entry_script_path = "train.py" ,
89- )
90-
91- assert command [0 ] == "mpirun"
92- assert "--host" in command
93- assert "algo-1,algo-2" in command
94- assert "-np" in command
95- assert "2" in command
96- assert f"NCCL_SOCKET_IFNAME=eth0" in " " .join (command )
80+ test_network_interface = "eth0"
81+ test_instance_type = "ml.p4d.24xlarge"
82+
83+ with patch .dict (
84+ os .environ ,
85+ {
86+ "SM_NETWORK_INTERFACE_NAME" : test_network_interface ,
87+ "SM_CURRENT_INSTANCE_TYPE" : test_instance_type ,
88+ },
89+ ):
90+ command = get_mpirun_command (
91+ host_count = 2 ,
92+ host_list = ["algo-1" , "algo-2" ],
93+ num_processes = 2 ,
94+ additional_options = [],
95+ entry_script_path = "train.py" ,
96+ )
97+
98+ # Basic command structure checks
99+ assert command [0 ] == "mpirun"
100+ assert "--host" in command
101+ assert "algo-1,algo-2" in command
102+ assert "-np" in command
103+ assert "2" in command
104+
105+ # Network interface check
106+ expected_nccl_config = f"NCCL_SOCKET_IFNAME={ test_network_interface } "
107+ command_str = " " .join (command )
108+ assert expected_nccl_config in command_str
97109
98110
99111@patch ("sagemaker.modules.train.container_drivers.mpi_utils._can_connect" )
0 commit comments