
Commit 18e31ab

Merge pull request #142 from tharittk/nccl_asarray_bug
fix nccl subcommunicator bug with asarray()
2 parents 45bee36 + 287fcb3, commit 18e31ab


3 files changed: +54 -9 lines changed


pylops_mpi/DistributedArray.py

Lines changed: 17 additions & 8 deletions
@@ -340,12 +340,7 @@ def local_shapes(self):
         local_shapes : :obj:`list`
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            # gather tuple of shapes from every rank and copy from GPU to CPU
-            all_tuples = self._allgather(self.local_shape).get()
-            # NCCL returns the flat array that packs every tuple as 1-dimensional array
-            # unpack each tuple from each rank
-            tuple_len = len(self.local_shape)
-            return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+            return self._nccl_local_shapes(False)
         else:
             return self._allgather(self.local_shape)

@@ -380,8 +375,8 @@ def asarray(self, masked: bool = False):
             return self.local_array

         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_asarray(self.sub_comm if masked else self.base_comm,
-                                self.local_array, self.local_shapes, self.axis)
+            return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
+                                self.local_array, self._nccl_local_shapes(masked), self.axis)
         else:
             # Gather all the local arrays and apply concatenation.
             if masked:
@@ -554,6 +549,20 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
             self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
         return recv_buf

+    def _nccl_local_shapes(self, masked: bool):
+        """Get the list of shapes of every GPU in the communicator
+        """
+        # gather tuple of shapes from every rank within the communicator and copy from GPU to CPU
+        if masked:
+            all_tuples = self._allgather_subcomm(self.local_shape).get()
+        else:
+            all_tuples = self._allgather(self.local_shape).get()
+        # NCCL returns the flat array that packs every tuple as 1-dimensional array
+        # unpack each tuple from each rank
+        tuple_len = len(self.local_shape)
+        local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+        return local_shapes
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,

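For context on the new _nccl_local_shapes helper: NCCL's allgather hands back a single flat 1-D array that packs every rank's local_shape tuple, and the helper slices that flat array back into per-rank tuples; with masked=True it gathers over the subcommunicator (self._allgather_subcomm), so the resulting list only covers the ranks sharing the caller's mask value. A minimal sketch of the slicing step, with made-up shapes and a plain Python list standing in for the CuPy array returned by the collective:

# Suppose three ranks each contributed a 2-element local_shape: (10, 5), (10, 5), (12, 5).
# After the allgather they arrive packed into one flat array.
all_tuples = [10, 5, 10, 5, 12, 5]

tuple_len = 2  # len(self.local_shape) on the calling rank
local_shapes = [tuple(all_tuples[i: i + tuple_len])
                for i in range(0, len(all_tuples), tuple_len)]
print(local_shapes)  # [(10, 5), (10, 5), (12, 5)]
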
pylops_mpi/utils/_nccl.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
     )
     if recv_buf is None:
         recv_buf = cp.zeros(
-            MPI.COMM_WORLD.Get_size() * send_buf.size,
+            nccl_comm.size() * send_buf.size,
             dtype=send_buf.dtype,
         )
     nccl_comm.allGather(

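The one-line change in nccl_allgather sizes the receive buffer by the communicator that actually runs the collective rather than by MPI.COMM_WORLD. With subcommunicators the two differ, so the old sizing over-allocated and left a zero-filled tail that downstream unpacking (such as the local-shapes logic above) could misread. A rough illustration of the arithmetic, with purely hypothetical sizes:

# 8 MPI ranks split into 2 NCCL subcommunicators of 4 ranks each;
# each rank sends a 6-element buffer into the allgather.
world_size = 8      # MPI.COMM_WORLD.Get_size()
subcomm_size = 4    # nccl_comm.size() for the subcommunicator
send_size = 6

old_recv_size = world_size * send_size    # 48 elements, only 24 of which get written
new_recv_size = subcomm_size * send_size  # 24 elements, exactly what 4 ranks contribute
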
tests_nccl/test_distributedarray_nccl.py

Lines changed: 36 additions & 0 deletions
@@ -307,6 +307,42 @@ def test_distributed_norm_nccl(par):
     assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13)


+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par6), (par8)])
+def test_distributed_masked_nccl(par):
+    """Test Asarray with masked array"""
+    # Number of subcommunicators
+    if MPI.COMM_WORLD.Get_size() % 2 == 0:
+        nsub = 2
+    elif MPI.COMM_WORLD.Get_size() % 3 == 0:
+        nsub = 3
+    else:
+        pass
+    subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
+    mask = np.repeat(np.arange(nsub), subsize)
+
+    # Replicate x as required in masked arrays
+    x_gpu = cp.asarray(par['x'])
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, par['axis'], 0)
+    for isub in range(1, nsub):
+        x_gpu[(x_gpu.shape[0] // nsub) * isub:(x_gpu.shape[0] // nsub) * (isub + 1)] = x_gpu[:x_gpu.shape[0] // nsub]
+    if par['axis'] != 0:
+        x_gpu = np.swapaxes(x_gpu, 0, par['axis'])
+
+    arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, partition=par['partition'], mask=mask, axis=par['axis'])
+
+    # Global view
+    xloc = arr.asarray()
+    assert xloc.shape == x_gpu.shape
+
+    # Global masked view
+    xmaskedloc = arr.asarray(masked=True)
+    xmasked_shape = list(x_gpu.shape)
+    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
+    assert xmaskedloc.shape == tuple(xmasked_shape)
+
+
 @pytest.mark.mpi(min_size=2)
 @pytest.mark.parametrize(
     "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)]

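In the new test, the mask built with np.repeat is what assigns ranks to subcommunicators, and it is why the masked global view shrinks by a factor of nsub along the distributed axis. A small standalone sketch of that grouping for a hypothetical 4-rank run (sizes illustrative only):

import numpy as np

nsub = 2
world_size = 4                          # stand-in for MPI.COMM_WORLD.Get_size()
subsize = max(1, world_size // nsub)    # 2 ranks per subcommunicator
mask = np.repeat(np.arange(nsub), subsize)
print(mask)  # [0 0 1 1] -> ranks {0, 1} and {2, 3} each form a subcommunicator

# asarray(masked=True) concatenates only within the caller's group, so the test expects
# the masked shape to equal the full shape divided by nsub along par['axis'].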