Skip to content

Commit 9693d9b

Browse files
committed
base_comm instantiation - suggested in PR
1 parent 3dc41fe commit 9693d9b

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ class DistributedArray:
124124
MPI Communicator over which array is distributed.
125125
Defaults to ``mpi4py.MPI.COMM_WORLD``.
126126
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
127-
NCCL Communicator over which array is distributed.
127+
NCCL Communicator over which array is distributed. Whenever NCCL
128+
Communicator is provided, the base_comm will be set to MPI.COMM_WORLD.
128129
partition : :obj:`Partition`, optional
129130
Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``.
130131
axis : :obj:`int`, optional
@@ -161,8 +162,11 @@ def __init__(self, global_shape: Union[Tuple, Integral],
161162

162163
self.dtype = dtype
163164
self._global_shape = _value_or_sized_to_tuple(global_shape)
164-
self._base_comm = base_comm
165165
self._base_comm_nccl = base_comm_nccl
166+
if base_comm_nccl is None:
167+
self._base_comm = base_comm
168+
else:
169+
self._base_comm = MPI.COMM_WORLD
166170
self._partition = partition
167171
self._axis = axis
168172
self._mask = mask

0 commit comments

Comments
 (0)