@@ -22,7 +22,11 @@ def set_environment_variables_pytest_single_process():
 def set_environment_variables_pytest_multi_process(
     rank: int = 0, world_size: int = 1
 ) -> None:
-    port = 29500 + random.randint(1, 1000)
+    # Use existing MASTER_PORT if set, otherwise generate a random one
+    if "MASTER_PORT" not in os.environ:
+        port = 29500 + random.randint(1, 1000)
+        os.environ["MASTER_PORT"] = str(port)
+
     # these variables are set by mpirun -n 2
     local_rank = int(
         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
@@ -32,11 +36,15 @@ def set_environment_variables_pytest_multi_process(
     # Set up environment variables to run with mpirun
     os.environ["RANK"] = str(local_rank)
     os.environ["WORLD_SIZE"] = str(world_size)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(port)
-
-    # Necessary to assign a device to each rank.
-    torch.cuda.set_device(local_rank)
+    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
+
+    # Takes into account 2 processes sharing 1 GPU
+    num_gpus = torch.cuda.device_count()
+    if num_gpus > 0:
+        gpu_id = local_rank % num_gpus
+        torch.cuda.set_device(gpu_id)
+    else:
+        raise RuntimeError("No CUDA devices available for distributed testing")
 
     # We use nccl backend
     dist.init_process_group("nccl")
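
For reference, a minimal sketch of how the patched helper might be exercised under `mpirun -n 2 python -m pytest ...`. The import path, test name, and assertions below are assumptions for illustration only, not part of this change:

```python
# Hypothetical test: import path and test body are illustrative assumptions.
import os

import torch
import torch.distributed as dist

from tests.conftest import set_environment_variables_pytest_multi_process  # hypothetical path


def test_all_reduce_across_ranks():
    # Under `mpirun -n 2`, OMPI_COMM_WORLD_SIZE is set by the launcher.
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
    set_environment_variables_pytest_multi_process(world_size=world_size)

    # Each rank contributes its own index; the all-reduced sum must equal
    # 0 + 1 + ... + (world_size - 1).
    tensor = torch.tensor([float(dist.get_rank())], device="cuda")
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    assert tensor.item() == sum(range(dist.get_world_size()))

    dist.destroy_process_group()
```

Because MASTER_PORT and MASTER_ADDR are now only set when absent from the environment, a launcher can pin them explicitly (e.g. to avoid port collisions on shared CI machines) and the helper will respect those values.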