
Commit ffd3707

private method calculate local_shapes for NCCL
1 parent e27b82a commit ffd3707

File tree

1 file changed (+17, -12 lines)


pylops_mpi/DistributedArray.py

Lines changed: 17 additions & 12 deletions
@@ -340,12 +340,7 @@ def local_shapes(self):
         local_shapes : :obj:`list`
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            # gather tuple of shapes from every rank and copy from GPU to CPU
-            all_tuples = self._allgather(self.local_shape).get()
-            # NCCL returns the flat array that packs every tuple as 1-dimensional array
-            # unpack each tuple from each rank
-            tuple_len = len(self.local_shape)
-            return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+            return self._nccl_local_shapes(self.base_comm_nccl)
         else:
             return self._allgather(self.local_shape)
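
For context, `local_shapes` is the public entry point callers use to learn how the global array is split across ranks; on NCCL-enabled setups it now delegates to the new private helper. A minimal usage sketch (illustrative shapes; assumes an MPI launch such as `mpiexec -n 2 python script.py`):

    from pylops_mpi import DistributedArray

    # Distribute a 10 x 5 array over the first axis; each rank owns a slice
    arr = DistributedArray(global_shape=(10, 5), axis=0)

    # Gathers every rank's local shape, e.g. [(5, 5), (5, 5)] with two ranks;
    # with NCCL enabled this now goes through _nccl_local_shapes
    print(arr.local_shapes)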

@@ -380,12 +375,7 @@ def asarray(self, masked: bool = False):
             return self.local_array
 
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            if masked:
-                all_tuples = self._allgather_subcomm(self.local_shape).get()
-                tuple_len = len(self.local_shape)
-                local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
-            else:
-                local_shapes = self.local_shapes
+            local_shapes = self._nccl_local_shapes(self.sub_comm if masked else self.base_comm_nccl)
             return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
                                 self.local_array, local_shapes, self.axis)
         else:
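
The gathered `local_shapes` feed `nccl_asarray`, which stitches the per-rank pieces back into the global array along the partition axis. A CPU-only NumPy analogue of that reassembly (illustrative only, not the library's NCCL code path):

    import numpy as np

    axis = 0
    # Per-rank pieces and their shapes, as local_shapes would report them
    pieces = [np.ones((5, 5)), np.zeros((5, 5))]
    local_shapes = [p.shape for p in pieces]

    # Reassemble the global array by concatenating along the partition axis
    global_arr = np.concatenate(pieces, axis=axis)
    assert global_arr.shape == (sum(s[axis] for s in local_shapes), 5)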
@@ -560,6 +550,21 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
             self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
         return recv_buf
 
+    def _nccl_local_shapes(self, nccl_comm: NcclCommunicatorType):
+        """Get the list of shapes of every GPU in the communicator
+        """
+        # gather tuple of shapes from every rank within the communicator and copy from GPU to CPU
+        if nccl_comm == self.sub_comm:
+            all_tuples = self._allgather_subcomm(self.local_shape).get()
+        else:
+            assert (nccl_comm == self.base_comm_nccl)
+            all_tuples = self._allgather(self.local_shape).get()
+        # NCCL returns the flat array that packs every tuple as 1-dimensional array
+        # unpack each tuple from each rank
+        tuple_len = len(self.local_shape)
+        local_shapes = [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+        return local_shapes
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
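
The new helper relies on the fact that an NCCL allgather of the shape tuples comes back as a single flat buffer rather than a list of tuples. A pure-Python sketch of the unpacking step (no GPU or NCCL needed; the gathered buffer is simulated):

    local_shape = (5, 10)                 # shape held by the calling rank
    tuple_len = len(local_shape)          # every rank contributes the same number of dims

    # Simulated allgather result: three ranks' shapes packed back to back
    all_tuples = [5, 10, 5, 10, 6, 10]

    # Slice the flat buffer back into one tuple per rank, as _nccl_local_shapes does
    local_shapes = [tuple(all_tuples[i : i + tuple_len])
                    for i in range(0, len(all_tuples), tuple_len)]
    print(local_shapes)  # [(5, 10), (5, 10), (6, 10)]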
