We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5362ed2 commit 030c5fcCopy full SHA for 030c5fc
pylops_mpi/utils/_nccl.py
@@ -107,9 +107,14 @@ def initialize_nccl_comm() -> nccl.NcclCommunicator:
107
comm = MPI.COMM_WORLD
108
rank = comm.Get_rank()
109
size = comm.Get_size()
110
+
111
+ # Create a communicator for ranks on the same node
112
+ node_comm = comm.Split_type(MPI.COMM_TYPE_SHARED)
113
+ ranks_on_node = node_comm.Get_size()
114
115
device_id = int(
116
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
- or rank % cp.cuda.runtime.getDeviceCount()
117
+ or (rank % ranks_on_node) % cp.cuda.runtime.getDeviceCount()
118
)
119
cp.cuda.Device(device_id).use()
120
0 commit comments