25 changes: 17 additions & 8 deletions pylops_mpi/DistributedArray.py
@@ -340,12 +340,7 @@ def local_shapes(self):
         local_shapes : :obj:`list`
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            # gather tuple of shapes from every rank and copy from GPU to CPU
-            all_tuples = self._allgather(self.local_shape).get()
-            # NCCL returns the flat array that packs every tuple as 1-dimensional array
-            # unpack each tuple from each rank
-            tuple_len = len(self.local_shape)
-            return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+            return self._nccl_local_shapes(False)
         else:
             return self._allgather(self.local_shape)

@@ -380,8 +375,8 @@ def asarray(self, masked: bool = False):
             return self.local_array
 
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_asarray(self.sub_comm if masked else self.base_comm,
-                                self.local_array, self.local_shapes, self.axis)
+            return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
+                                self.local_array, self._nccl_local_shapes(masked), self.axis)
         else:
             # Gather all the local arrays and apply concatenation.
             if masked:
@@ -554,6 +549,20 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
             self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
         return recv_buf
 
+    def _nccl_local_shapes(self, masked: bool):
+        """Get the list of shapes of every GPU in the communicator
+        """
+        # gather the tuple of shapes from every rank within the communicator and copy from GPU to CPU
+        if masked:
+            all_tuples = self._allgather_subcomm(self.local_shape).get()
+        else:
+            all_tuples = self._allgather(self.local_shape).get()
+        # NCCL returns a flat array that packs every tuple into a 1-dimensional array
+        # unpack each rank's tuple from it
+        tuple_len = len(self.local_shape)
+        local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+        return local_shapes

     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
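Why the new helper exists: `nccl_allgather` returns one flat buffer rather than a list of per-rank objects, so the shape tuples gathered from each rank must be unpacked by hand, and `_nccl_local_shapes` now centralizes that logic for both `local_shapes` and `asarray`. A minimal, self-contained sketch of the unpacking step (the shape values are made up for illustration):

```python
# Illustrative sketch, not from the diff: unpack a flat allgather result
# into per-rank shape tuples. Assume three ranks holding (3, 5), (3, 5), (2, 5).
all_tuples = [3, 5, 3, 5, 2, 5]   # flat buffer as NCCL would return it
tuple_len = 2                     # len(local_shape) on this rank
local_shapes = [tuple(all_tuples[i:i + tuple_len])
                for i in range(0, len(all_tuples), tuple_len)]
assert local_shapes == [(3, 5), (3, 5), (2, 5)]
```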
2 changes: 1 addition & 1 deletion pylops_mpi/utils/_nccl.py
@@ -149,7 +149,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
     )
     if recv_buf is None:
         recv_buf = cp.zeros(
-            MPI.COMM_WORLD.Get_size() * send_buf.size,
+            nccl_comm.size() * send_buf.size,
             dtype=send_buf.dtype,
         )
     nccl_comm.allGather(
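This one-line fix matters for masked arrays: the NCCL communicator passed in may be a sub-communicator, so sizing the receive buffer by `MPI.COMM_WORLD` allocates space for every world rank even though the allgather only fills the sub-communicator's blocks. A small sketch of the arithmetic, with assumed sizes:

```python
# Illustrative sketch, not from the diff: why the buffer must follow the
# communicator, not the world. Assume 4 world ranks split into two
# sub-communicators of 2 ranks each.
world_size = 4
nccl_comm_size = 2        # e.g. the sub-communicator of a masked array
send_count = 3            # elements contributed by each rank
correct_len = nccl_comm_size * send_count   # 6: exactly what allGather writes
old_len = world_size * send_count           # 12: half the buffer stays unwritten
assert (correct_len, old_len) == (6, 12)
```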
36 changes: 36 additions & 0 deletions tests_nccl/test_distributedarray_nccl.py
@@ -307,6 +307,42 @@ def test_distributed_norm_nccl(par):
     assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13)
 
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par6), (par8)])
+def test_distributed_masked_nccl(par):
+    """Test asarray with masked array"""
+    # Number of subcommunicators
+    if MPI.COMM_WORLD.Get_size() % 2 == 0:
+        nsub = 2
+    elif MPI.COMM_WORLD.Get_size() % 3 == 0:
+        nsub = 3
+    else:
+        pytest.skip("MPI size must be divisible by 2 or 3")
+    subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub)
+    mask = np.repeat(np.arange(nsub), subsize)
+
+    # Replicate x as required in masked arrays
+    x_gpu = cp.asarray(par['x'])
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, par['axis'], 0)
+    for isub in range(1, nsub):
+        x_gpu[(x_gpu.shape[0] // nsub) * isub:(x_gpu.shape[0] // nsub) * (isub + 1)] = x_gpu[:x_gpu.shape[0] // nsub]
+    if par['axis'] != 0:
+        x_gpu = cp.swapaxes(x_gpu, 0, par['axis'])
+
+    arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, partition=par['partition'], mask=mask, axis=par['axis'])
+
+    # Global view
+    xloc = arr.asarray()
+    assert xloc.shape == x_gpu.shape
+
+    # Global masked view
+    xmaskedloc = arr.asarray(masked=True)
+    xmasked_shape = list(x_gpu.shape)
+    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
+    assert xmaskedloc.shape == tuple(xmasked_shape)

 @pytest.mark.mpi(min_size=2)
 @pytest.mark.parametrize(
     "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)]
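For reference, a sketch of how the test's mask groups ranks and what shape the masked global view takes; the `nsub`, world size, and global shape below are assumed values, not taken from the `par6`/`par8` parametrizations:

```python
# Illustrative sketch, not from the diff: mask layout and masked-view shape.
import numpy as np

nsub, world_size = 2, 4
subsize = max(1, world_size // nsub)
mask = np.repeat(np.arange(nsub), subsize)
assert mask.tolist() == [0, 0, 1, 1]   # ranks 0-1 -> group 0, ranks 2-3 -> group 1

global_shape = (8, 5)                  # hypothetical, replicated along axis=0
axis = 0
masked_shape = list(global_shape)
masked_shape[axis] //= nsub            # each group holds one replica
assert tuple(masked_shape) == (4, 5)   # shape of arr.asarray(masked=True)
```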