Skip to content

Commit 030c5fc

Browse files
committed
feat: modify initialize_nccl_comm to handle nodes with more GPUs than ranks
1 parent 5362ed2 commit 030c5fc

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pylops_mpi/utils/_nccl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,14 @@ def initialize_nccl_comm() -> nccl.NcclCommunicator:
107107
comm = MPI.COMM_WORLD
108108
rank = comm.Get_rank()
109109
size = comm.Get_size()
110+
111+
# Create a communicator for ranks on the same node
112+
node_comm = comm.Split_type(MPI.COMM_TYPE_SHARED)
113+
ranks_on_node = node_comm.Get_size()
114+
110115
device_id = int(
111116
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
112-
or rank % cp.cuda.runtime.getDeviceCount()
117+
or (rank % ranks_on_node) % cp.cuda.runtime.getDeviceCount()
113118
)
114119
cp.cuda.Device(device_id).use()
115120

0 commit comments

Comments
 (0)