diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 590d7005..a0e4eab7 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -340,12 +340,7 @@ def local_shapes(self): local_shapes : :obj:`list` """ if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - # gather tuple of shapes from every rank and copy from GPU to CPU - all_tuples = self._allgather(self.local_shape).get() - # NCCL returns the flat array that packs every tuple as 1-dimensional array - # unpack each tuple from each rank - tuple_len = len(self.local_shape) - return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)] + return self._nccl_local_shapes(False) else: return self._allgather(self.local_shape) @@ -380,8 +375,8 @@ def asarray(self, masked: bool = False): return self.local_array if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_asarray(self.sub_comm if masked else self.base_comm, - self.local_array, self.local_shapes, self.axis) + return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl, + self.local_array, self._nccl_local_shapes(masked), self.axis) else: # Gather all the local arrays and apply concatenation. 
if masked: @@ -554,6 +549,20 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None): self.base_comm.Recv(buf=recv_buf, source=source, tag=tag) return recv_buf + def _nccl_local_shapes(self, masked: bool): + """Get the list of shapes of every GPU in the communicator + """ + # gather tuple of shapes from every rank within the communicator and copy from GPU to CPU + if masked: + all_tuples = self._allgather_subcomm(self.local_shape).get() + else: + all_tuples = self._allgather(self.local_shape).get() + # NCCL returns the flat array that packs every tuple as 1-dimensional array + # unpack each tuple from each rank + tuple_len = len(self.local_shape) + local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)] + return local_shapes + def __neg__(self): arr = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm, diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index d183fc58..f2fe7d9a 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -149,7 +149,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray: ) if recv_buf is None: recv_buf = cp.zeros( - MPI.COMM_WORLD.Get_size() * send_buf.size, + nccl_comm.size() * send_buf.size, dtype=send_buf.dtype, ) nccl_comm.allGather( diff --git a/tests_nccl/test_distributedarray_nccl.py b/tests_nccl/test_distributedarray_nccl.py index 3478c8a8..7c3f510b 100644 --- a/tests_nccl/test_distributedarray_nccl.py +++ b/tests_nccl/test_distributedarray_nccl.py @@ -307,6 +307,42 @@ def test_distributed_norm_nccl(par): assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13) +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par6), (par8)]) +def test_distributed_masked_nccl(par): + """Test Asarray with masked array""" + # Number of subcommunicators + if MPI.COMM_WORLD.Get_size() % 2 == 0: + nsub = 2 + elif MPI.COMM_WORLD.Get_size() % 3 == 0: + nsub = 3 + else: + pass + 
subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub) + mask = np.repeat(np.arange(nsub), subsize) + + # Replicate x as required in masked arrays + x_gpu = cp.asarray(par['x']) + if par['axis'] != 0: + x_gpu = cp.swapaxes(x_gpu, par['axis'], 0) + for isub in range(1, nsub): + x_gpu[(x_gpu.shape[0] // nsub) * isub:(x_gpu.shape[0] // nsub) * (isub + 1)] = x_gpu[:x_gpu.shape[0] // nsub] + if par['axis'] != 0: + x_gpu = np.swapaxes(x_gpu, 0, par['axis']) + + arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, partition=par['partition'], mask=mask, axis=par['axis']) + + # Global view + xloc = arr.asarray() + assert xloc.shape == x_gpu.shape + + # Global masked view + xmaskedloc = arr.asarray(masked=True) + xmasked_shape = list(x_gpu.shape) + xmasked_shape[par['axis']] = int(xmasked_shape[par['axis']] // nsub) + assert xmaskedloc.shape == tuple(xmasked_shape) + + @pytest.mark.mpi(min_size=2) @pytest.mark.parametrize( "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)]