Skip to content

Commit 287fcb3

Browse files
committed
revised local shapes calculation flow
1 parent ffd3707 commit 287fcb3

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def local_shapes(self):
340340
local_shapes : :obj:`list`
341341
"""
342342
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
343-
return self._nccl_local_shapes(self.base_comm_nccl)
343+
return self._nccl_local_shapes(False)
344344
else:
345345
return self._allgather(self.local_shape)
346346

@@ -375,9 +375,8 @@ def asarray(self, masked: bool = False):
375375
return self.local_array
376376

377377
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
378-
local_shapes = self._nccl_local_shapes(self.sub_comm if masked else self.base_comm_nccl)
379378
return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
380-
self.local_array, local_shapes, self.axis)
379+
self.local_array, self._nccl_local_shapes(masked), self.axis)
381380
else:
382381
# Gather all the local arrays and apply concatenation.
383382
if masked:
@@ -550,14 +549,13 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
550549
self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
551550
return recv_buf
552551

553-
def _nccl_local_shapes(self, nccl_comm: NcclCommunicatorType):
552+
def _nccl_local_shapes(self, masked: bool):
554553
"""Get the list of shapes of every GPU in the communicator
555554
"""
556555
# gather tuple of shapes from every rank within the communicator and copy from GPU to CPU
557-
if nccl_comm == self.sub_comm:
556+
if masked:
558557
all_tuples = self._allgather_subcomm(self.local_shape).get()
559558
else:
560-
assert (nccl_comm == self.base_comm_nccl)
561559
all_tuples = self._allgather(self.local_shape).get()
562560
# NCCL returns a flat array that packs every tuple as a 1-dimensional array
563561
# unpack each tuple from each rank

0 commit comments

Comments
 (0)