Merged
10 changes: 8 additions & 2 deletions pylops_mpi/DistributedArray.py

@@ -380,8 +380,14 @@ def asarray(self, masked: bool = False):
             return self.local_array
 
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_asarray(self.sub_comm if masked else self.base_comm,
-                                self.local_array, self.local_shapes, self.axis)
+            if masked:
+                all_tuples = self._allgather_subcomm(self.local_shape).get()
+                tuple_len = len(self.local_shape)
+                local_shapes = [tuple(all_tuples[i: i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+            else:
+                local_shapes = self.local_shapes
+            return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
+                                self.local_array, local_shapes, self.axis)
         else:
             # Gather all the local arrays and apply concatenation.
             if masked:
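The masked branch above works around the fact that an NCCL allgather returns one flat GPU buffer rather than a Python list of tuples: each rank contributes its local_shape, and the concatenated result has to be re-chunked into one shape tuple per rank before it can be passed to nccl_asarray. A minimal sketch of that regrouping, using hypothetical shapes for a 3-rank sub-communicator:

    # Hypothetical flattened allgather result for a 3-rank sub-communicator
    # whose 2-D local arrays have shapes (10, 4), (10, 4) and (12, 4).
    all_tuples = [10, 4, 10, 4, 12, 4]
    tuple_len = 2  # len(local_shape) for 2-D local arrays

    # Re-chunk the flat sequence into one shape tuple per rank
    local_shapes = [
        tuple(all_tuples[i: i + tuple_len])
        for i in range(0, len(all_tuples), tuple_len)
    ]
    assert local_shapes == [(10, 4), (10, 4), (12, 4)]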
2 changes: 1 addition & 1 deletion pylops_mpi/utils/_nccl.py

@@ -149,7 +149,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
     )
     if recv_buf is None:
         recv_buf = cp.zeros(
-            MPI.COMM_WORLD.Get_size() * send_buf.size,
+            nccl_comm.size() * send_buf.size,
             dtype=send_buf.dtype,
         )
     nccl_comm.allGather(
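This one-line fix matters as soon as the allgather runs on a sub-communicator: the receive buffer needs one send_buf-sized slot per rank of the communicator that actually participates, and sizing it by MPI.COMM_WORLD over-allocates and misplaces the per-rank offsets whenever nccl_comm spans fewer ranks than the world. A hypothetical helper, not part of pylops-mpi, that captures the sizing rule:

    import cupy as cp

    def alloc_allgather_recv_buf(n_ranks: int, send_buf: cp.ndarray) -> cp.ndarray:
        # One send_buf-sized slot per participating rank; n_ranks should be
        # the size of the communicator doing the allgather, i.e.
        # nccl_comm.size(), not MPI.COMM_WORLD.Get_size().
        return cp.zeros(n_ranks * send_buf.size, dtype=send_buf.dtype)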
36 changes: 36 additions & 0 deletions tests_nccl/test_distributedarray_nccl.py

@@ -307,6 +307,42 @@ def test_distributed_norm_nccl(par):
     assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13)
 
 
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par6), (par8)])
+def test_distributed_masked_nccl(par):
+    """Test asarray with masked array"""
+    # Number of subcommunicators
+    if MPI.COMM_WORLD.Get_size() % 2 == 0:
+        nsub = 2
+    elif MPI.COMM_WORLD.Get_size() % 3 == 0:
+        nsub = 3
+    else:
+        pytest.skip("World size must be divisible by 2 or 3")
+    subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
+    mask = np.repeat(np.arange(nsub), subsize)
+
+    # Replicate x as required in masked arrays
+    x_gpu = cp.asarray(par['x'])
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, par['axis'], 0)
+    for isub in range(1, nsub):
+        x_gpu[(x_gpu.shape[0] // nsub) * isub:(x_gpu.shape[0] // nsub) * (isub + 1)] = x_gpu[:x_gpu.shape[0] // nsub]
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, 0, par['axis'])
+
+    arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, partition=par['partition'], mask=mask, axis=par['axis'])
+
+    # Global view
+    xloc = arr.asarray()
+    assert xloc.shape == x_gpu.shape
+
+    # Global masked view
+    xmaskedloc = arr.asarray(masked=True)
+    xmasked_shape = list(x_gpu.shape)
+    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
+    assert xmaskedloc.shape == tuple(xmasked_shape)
+
+
 @pytest.mark.mpi(min_size=2)
 @pytest.mark.parametrize(
     "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)]
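For intuition on the mask used in the test: mask assigns each world rank to a sub-communicator, and asarray(masked=True) gathers only within one such group, so the masked global view shrinks by a factor of nsub along the distribution axis. A small sketch, assuming a hypothetical 4-rank run:

    import numpy as np

    # Assumed 4-rank run split into nsub = 2 sub-communicators
    nsub, size = 2, 4
    subsize = max(1, size // nsub)
    mask = np.repeat(np.arange(nsub), subsize)
    print(mask)  # [0 0 1 1] -> ranks 0-1 form sub-comm 0, ranks 2-3 form sub-comm 1

    # With a global axis length of 8 replicated across the 2 groups, the
    # masked view along that axis has length 8 // nsub == 4.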