
Commit e27b82a

fix nccl subcommunicator bug with asarray()
1 parent 45bee36

3 files changed: 45 additions (+), 3 deletions (-)

pylops_mpi/DistributedArray.py (8 additions, 2 deletions)

@@ -380,8 +380,14 @@ def asarray(self, masked: bool = False):
             return self.local_array

         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_asarray(self.sub_comm if masked else self.base_comm,
-                                self.local_array, self.local_shapes, self.axis)
+            if masked:
+                all_tuples = self._allgather_subcomm(self.local_shape).get()
+                tuple_len = len(self.local_shape)
+                local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+            else:
+                local_shapes = self.local_shapes
+            return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
+                                self.local_array, local_shapes, self.axis)
         else:
             # Gather all the local arrays and apply concatenation.
             if masked:
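The core of the fix is the regrouping step: the allgather over the subcommunicator returns every rank's local_shape flattened into one sequence, which has to be re-chunked into one tuple per rank before nccl_asarray can concatenate along the right axis. A minimal sketch of that regrouping in plain Python, with hypothetical shapes for three ranks holding 2-D local arrays:

# Hypothetical flat result of allgathering local_shape over 3 ranks:
# rank0=(10, 4), rank1=(12, 4), rank2=(11, 4)
all_tuples = [10, 4, 12, 4, 11, 4]
tuple_len = 2  # len(local_shape), i.e. the number of dimensions

# Re-chunk the flat sequence into one shape tuple per rank
local_shapes = [tuple(all_tuples[i : i + tuple_len])
                for i in range(0, len(all_tuples), tuple_len)]
print(local_shapes)  # [(10, 4), (12, 4), (11, 4)]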

pylops_mpi/utils/_nccl.py (1 addition, 1 deletion)

@@ -149,7 +149,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
         )
     if recv_buf is None:
         recv_buf = cp.zeros(
-            MPI.COMM_WORLD.Get_size() * send_buf.size,
+            nccl_comm.size() * send_buf.size,
            dtype=send_buf.dtype,
        )
    nccl_comm.allGather(
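This one-line change matters because an allgather concatenates one send buffer per rank of the communicator it runs on: sizing the receive buffer by MPI.COMM_WORLD leaves unwritten trailing elements (and a wrong result length) whenever nccl_comm is a subcommunicator. A sketch of the sizing rule, using numpy and hypothetical sizes:

import numpy as np

world_size, subcomm_size, send_size = 8, 4, 100  # hypothetical sizes
send_buf = np.zeros(send_size)

# Correct: the allgather runs on the 4-rank subcommunicator, so it
# writes exactly subcomm_size * send_size = 400 elements
recv_buf = np.zeros(subcomm_size * send_size)

# The old sizing, world_size * send_size = 800, would allocate 400
# trailing elements that a 4-rank allgather never fills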

tests_nccl/test_distributedarray_nccl.py (36 additions, 0 deletions)

@@ -307,6 +307,42 @@ def test_distributed_norm_nccl(par):
     assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13)


+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par6), (par8)])
+def test_distributed_masked_nccl(par):
+    """Test asarray with masked array"""
+    # Number of subcommunicators (skip sizes for which nsub would be undefined)
+    if MPI.COMM_WORLD.Get_size() % 2 == 0:
+        nsub = 2
+    elif MPI.COMM_WORLD.Get_size() % 3 == 0:
+        nsub = 3
+    else:
+        pytest.skip("MPI size must be divisible by 2 or 3")
+    subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
+    mask = np.repeat(np.arange(nsub), subsize)
+
+    # Replicate x as required in masked arrays
+    x_gpu = cp.asarray(par['x'])
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, par['axis'], 0)
+    for isub in range(1, nsub):
+        x_gpu[(x_gpu.shape[0] // nsub) * isub:(x_gpu.shape[0] // nsub) * (isub + 1)] = x_gpu[:x_gpu.shape[0] // nsub]
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, 0, par['axis'])
+
+    arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, partition=par['partition'], mask=mask, axis=par['axis'])
+
+    # Global view
+    xloc = arr.asarray()
+    assert xloc.shape == x_gpu.shape
+
+    # Global masked view
+    xmaskedloc = arr.asarray(masked=True)
+    xmasked_shape = list(x_gpu.shape)
+    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
+    assert xmaskedloc.shape == tuple(xmasked_shape)
+
+
 @pytest.mark.mpi(min_size=2)
 @pytest.mark.parametrize(
     "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)]
