Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions pylops_mpi/DistributedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,7 @@ def local_shapes(self):
local_shapes : :obj:`list`
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
# gather tuple of shapes from every rank and copy from GPU to CPU
all_tuples = self._allgather(self.local_shape).get()
# NCCL returns the flat array that packs every tuple as 1-dimensional array
# unpack each tuple from each rank
tuple_len = len(self.local_shape)
return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
return self._nccl_local_shapes(self.base_comm_nccl)
else:
return self._allgather(self.local_shape)

Expand Down Expand Up @@ -380,8 +375,9 @@ def asarray(self, masked: bool = False):
return self.local_array

if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
return nccl_asarray(self.sub_comm if masked else self.base_comm,
self.local_array, self.local_shapes, self.axis)
local_shapes = self._nccl_local_shapes(self.sub_comm if masked else self.base_comm_nccl)
return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
self.local_array, local_shapes, self.axis)
else:
# Gather all the local arrays and apply concatenation.
if masked:
Expand Down Expand Up @@ -554,6 +550,21 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
return recv_buf

def _nccl_local_shapes(self, nccl_comm: NcclCommunicatorType):
"""Get the list of shapes of every GPU in the communicator
"""
# gather tuple of shapes from every rank within the communicator and copy from GPU to CPU
if nccl_comm == self.sub_comm:
all_tuples = self._allgather_subcomm(self.local_shape).get()
else:
assert (nccl_comm == self.base_comm_nccl)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like asserts in code, they should only be used in tests... why not passing comm and masked and then have the if masked... as in the code you deleted, so we don't do twice the same checks?

I am also not so sure why before we had

else:
    local_shapes = self.local_shapes

but now also for the case without subcomm we repeat the creation of local_shapes

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Earlier, I have this

else:
    local_shapes = self.local_shapes

so that I can have single return statement return local_shapes , but that may look unnecessary.
Now I change the code to look like this:

    def _nccl_local_shapes(self, masked: bool):
        if masked:
            all_tuples = self._allgather_subcomm(self.local_shape).get()
        else:
            all_tuples = self._allgather(self.local_shape).get()
        tuple_len = len(self.local_shape)
        local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
        return local_shapes

It takes only masked: bool. We can assume if masked=True then we will use subcomm and we don't have to pass it.
And then have local_shapes() call the _nccl_local_shapes

    @property
    def local_shapes(self):
        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
            return self._nccl_local_shapes(False)
        else:
            return self._allgather(self.local_shape)

It is because I want to have this unpacking appear only in one place.

        tuple_len = len(self.local_shape)
        local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]

all_tuples = self._allgather(self.local_shape).get()
# NCCL returns the flat array that packs every tuple as 1-dimensional array
# unpack each tuple from each rank
tuple_len = len(self.local_shape)
local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
return local_shapes

def __neg__(self):
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
Expand Down
2 changes: 1 addition & 1 deletion pylops_mpi/utils/_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
)
if recv_buf is None:
recv_buf = cp.zeros(
MPI.COMM_WORLD.Get_size() * send_buf.size,
nccl_comm.size() * send_buf.size,
dtype=send_buf.dtype,
)
nccl_comm.allGather(
Expand Down
36 changes: 36 additions & 0 deletions tests_nccl/test_distributedarray_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,42 @@ def test_distributed_norm_nccl(par):
assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13)


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("par", [(par6), (par8)])
def test_distributed_masked_nccl(par):
    """Test asarray with a masked, NCCL-backed DistributedArray.

    Builds a DistributedArray whose ranks are grouped into ``nsub``
    subcommunicators via ``mask``, then checks both the global view and
    the masked (per-subcommunicator) view have the expected shapes.
    """
    # Number of subcommunicators. World sizes not divisible by 2 or 3
    # cannot be split evenly into identical groups, so skip explicitly:
    # the original ``else: pass`` left ``nsub`` unbound and the test
    # crashed with NameError on the next line for such sizes.
    world_size = MPI.COMM_WORLD.Get_size()
    if world_size % 2 == 0:
        nsub = 2
    elif world_size % 3 == 0:
        nsub = 3
    else:
        pytest.skip("MPI world size must be divisible by 2 or 3")
    subsize = max(1, world_size // nsub)
    mask = np.repeat(np.arange(nsub), subsize)

    # Replicate x along the distributed axis, as required in masked arrays:
    # every subcommunicator must hold an identical copy of the data.
    x_gpu = cp.asarray(par['x'])
    if par['axis'] != 0:
        x_gpu = cp.swapaxes(x_gpu, par['axis'], 0)
    for isub in range(1, nsub):
        x_gpu[(x_gpu.shape[0] // nsub) * isub:(x_gpu.shape[0] // nsub) * (isub + 1)] = x_gpu[:x_gpu.shape[0] // nsub]
    if par['axis'] != 0:
        # use cp.swapaxes (not np.swapaxes) to stay consistently on the GPU
        x_gpu = cp.swapaxes(x_gpu, 0, par['axis'])

    arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, partition=par['partition'], mask=mask, axis=par['axis'])

    # Global view: concatenation across the full communicator restores x
    xloc = arr.asarray()
    assert xloc.shape == x_gpu.shape

    # Global masked view: concatenation within one subcommunicator yields
    # a single (non-replicated) copy, i.e. 1/nsub of the distributed axis
    xmaskedloc = arr.asarray(masked=True)
    xmasked_shape = list(x_gpu.shape)
    xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub)
    assert xmaskedloc.shape == tuple(xmasked_shape)


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize(
"par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)]
Expand Down