Skip to content

Commit f362436

Browse files
committed
feat: adapt all comm calls in DistributedArray to new method signatures
1 parent b8bcd29 commit f362436

File tree

1 file changed

+34
-15
lines changed

1 file changed

+34
-15
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
nccl_message = deps.nccl_import("the DistributedArray module")
1616

1717
if nccl_message is None and cupy_message is None:
18-
from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split
18+
from pylops_mpi.utils._nccl import nccl_asarray, nccl_split
1919
from cupy.cuda.nccl import NcclCommunicator
2020
else:
2121
NcclCommunicator = Any
@@ -204,10 +204,7 @@ def __setitem__(self, index, value):
204204
the specified index positions.
205205
"""
206206
if self.partition is Partition.BROADCAST:
207-
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
208-
nccl_bcast(self.base_comm_nccl, self.local_array, index, value)
209-
else:
210-
self.local_array[index] = self.base_comm.bcast(value)
207+
self._bcast(self.local_array, index, value)
211208
else:
212209
self.local_array[index] = value
213210

@@ -343,7 +340,9 @@ def local_shapes(self):
343340
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
344341
return self._nccl_local_shapes(False)
345342
else:
346-
return self._allgather(self.local_shape)
343+
return self._allgather(self.base_comm,
344+
self.base_comm_nccl,
345+
self.local_shape)
347346

348347
@property
349348
def sub_comm(self):
@@ -383,7 +382,10 @@ def asarray(self, masked: bool = False):
383382
if masked:
384383
final_array = self._allgather_subcomm(self.local_array)
385384
else:
386-
final_array = self._allgather(self.local_array)
385+
final_array = self._allgather(self.base_comm,
386+
self.base_comm_nccl,
387+
self.local_array,
388+
engine=self.engine)
387389
return np.concatenate(final_array, axis=self.axis)
388390

389391
@classmethod
@@ -433,6 +435,7 @@ def to_dist(cls, x: NDArray,
433435
else:
434436
slices = [slice(None)] * x.ndim
435437
local_shapes = np.append([0], dist_array._allgather(
438+
base_comm, base_comm_nccl,
436439
dist_array.local_shape[axis]))
437440
sum_shapes = np.cumsum(local_shapes)
438441
slices[axis] = slice(sum_shapes[dist_array.rank],
@@ -480,7 +483,9 @@ def _nccl_local_shapes(self, masked: bool):
480483
if masked:
481484
all_tuples = self._allgather_subcomm(self.local_shape).get()
482485
else:
483-
all_tuples = self._allgather(self.local_shape).get()
486+
all_tuples = self._allgather(self.base_comm,
487+
self.base_comm_nccl,
488+
self.local_shape).get()
484489
# NCCL returns the flat array that packs every tuple as 1-dimensional array
485490
# unpack each tuple from each rank
486491
tuple_len = len(self.local_shape)
@@ -578,7 +583,9 @@ def dot(self, dist_array):
578583
y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
579584
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
580585
# Flatten the local arrays and calculate dot product
581-
return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
586+
return self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
587+
ncp.dot(x.local_array.flatten(), y.local_array.flatten()),
588+
engine=self.engine)
582589

583590
def _compute_vector_norm(self, local_array: NDArray,
584591
axis: int, ord: Optional[int] = None):
@@ -606,7 +613,9 @@ def _compute_vector_norm(self, local_array: NDArray,
606613
raise ValueError(f"norm-{ord} not possible for vectors")
607614
elif ord == 0:
608615
# Count non-zero then sum reduction
609-
recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64))
616+
recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
617+
ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64),
618+
engine=self.engine)
610619
elif ord == ncp.inf:
611620
# Calculate max followed by max reduction
612621
# CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
@@ -615,10 +624,14 @@ def _compute_vector_norm(self, local_array: NDArray,
615624
if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
616625
# CuPy + non-CUDA-aware MPI: This will call non-buffered communication
617626
# which return a list of object - must be copied back to a GPU memory.
618-
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
627+
recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
628+
send_buf.get(), recv_buf.get(),
629+
op=MPI.MAX, engine=self.engine)
619630
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
620631
else:
621-
recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
632+
recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
633+
send_buf, recv_buf, op=MPI.MAX,
634+
engine=self.engine)
622635
# TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL
623636
# the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it.
624637
# There may be a way to unify it - may be something to do with how we allocate the recv_buf.
@@ -629,14 +642,20 @@ def _compute_vector_norm(self, local_array: NDArray,
629642
# See the comment above in +infinity norm
630643
send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
631644
if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
632-
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
645+
recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
646+
send_buf.get(), recv_buf.get(),
647+
op=MPI.MIN, engine=self.engine)
633648
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
634649
else:
635-
recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN)
650+
recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
651+
send_buf, recv_buf,
652+
op=MPI.MIN, engine=self.engine)
636653
if self.base_comm_nccl:
637654
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
638655
else:
639-
recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis))
656+
recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
657+
ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis),
658+
engine=self.engine)
640659
recv_buf = ncp.power(recv_buf, 1.0 / ord)
641660
return recv_buf
642661

0 commit comments

Comments (0)