Commit 9e25d8e

Merge pull request #137 from tharittk/nccl-vstack
Add NCCL support to add_ghost_cells and operators in /basicoperators
2 parents 28ccdf2 + d7d07ab commit 9e25d8e
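
To make the intent of the change concrete, here is a minimal usage sketch (not part of the commit): it builds a NCCL communicator by hand on top of MPI using CuPy's cupy.cuda.nccl module, wraps per-rank CuPy arrays in a DistributedArray, and lets add_ghost_cells exchange halos GPU-to-GPU. It assumes CuPy built with NCCL support and one GPU per MPI rank; the DistributedArray keyword arguments (base_comm, base_comm_nccl, engine) are the ones touched in this diff, everything else is standard mpi4py/CuPy.

import cupy as cp
from cupy.cuda import nccl
from mpi4py import MPI
from pylops_mpi import DistributedArray

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
cp.cuda.Device(rank).use()  # one GPU per rank (assumption of this sketch)

# NCCL communicator: rank 0 creates the unique id, the others receive it over MPI
uid = nccl.get_unique_id() if rank == 0 else None
uid = comm.bcast(uid, root=0)
nccl_comm = nccl.NcclCommunicator(size, uid, rank)

# distributed CuPy array: the halo exchange inside add_ghost_cells now runs over NCCL
x = DistributedArray(global_shape=size * 10, base_comm=comm,
                     base_comm_nccl=nccl_comm, engine="cupy")
x[:] = cp.arange(10, dtype=x.dtype) + rank * 10
ghosted = x.add_ghost_cells(cells_front=1, cells_back=1)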

File tree

11 files changed (+1124, -48 lines)


pylops_mpi/DistributedArray.py

Lines changed: 67 additions & 15 deletions
@@ -14,7 +14,7 @@
 nccl_message = deps.nccl_import("the DistributedArray module")
 
 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -495,14 +495,46 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+        if deps.nccl_enabled and self.base_comm_nccl:
             return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
         else:
             if recv_buf is None:
                 return self.base_comm.allgather(send_buf)
             self.base_comm.Allgather(send_buf, recv_buf)
             return recv_buf
 
+    def _send(self, send_buf, dest, count=None, tag=None):
+        """Send operation
+        """
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if count is None:
+                # assume the whole array is sent
+                count = send_buf.size
+            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+        else:
+            self.base_comm.Send(send_buf, dest, tag)
+
+    def _recv(self, recv_buf=None, source=0, count=None, tag=None):
+        """Receive operation
+        """
+        # NCCL must be called with recv_buf: its size cannot be inferred from
+        # the other arguments, so the buffer cannot be allocated dynamically
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if recv_buf is not None:
+                if count is None:
+                    # assume the data fills the whole buffer
+                    count = recv_buf.size
+                nccl_recv(self.base_comm_nccl, recv_buf, source, count)
+                return recv_buf
+            else:
+                raise ValueError("Using recv with NCCL must also supply a receiver buffer")
+        else:
+            # MPI allows the receiver buffer to be optional
+            if recv_buf is None:
+                return self.base_comm.recv(source=source, tag=tag)
+            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
+            return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
@@ -540,6 +572,7 @@ def add(self, dist_array):
         self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
+                                    base_comm_nccl=self.base_comm_nccl,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
@@ -566,6 +599,7 @@ def multiply(self, dist_array):
 
         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
+                                        base_comm_nccl=self.base_comm_nccl,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
@@ -716,6 +750,8 @@ def ravel(self, order: Optional[str] = "C"):
         """
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
+                               base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                local_shapes=local_shapes,
                                mask=self.mask,
                                partition=self.partition,
@@ -744,41 +780,57 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
         -------
         ghosted_array : :obj:`numpy.ndarray`
             Ghosted Array
-
         """
         ghosted_array = self.local_array.copy()
+        ncp = get_module(self.engine)
         if cells_front is not None:
-            total_cells_front = self._allgather(cells_front) + [0]
+            # cells_front is a small array of int, so gather it explicitly over MPI
+            total_cells_front = self.base_comm.allgather(cells_front) + [0]
             # Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
+            send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
+            recv_shapes = self.local_shapes
             if self.rank != 0:
-                ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
-                                               axis=self.axis)
-            if self.rank != self.size - 1:
+                # From the receiver's perspective (rank), the recv buffer has the same shape as the sender's
+                # array (rank - 1) in every dimension except the size along axis=self.axis
+                recv_shape = list(recv_shapes[self.rank - 1])
+                recv_shape[self.axis] = total_cells_front[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                # The transfer of ghost cells can be skipped if len(recv_buf) == 0;
+                # NCCL also hangs on a zero-sized buffer, so this check is mandatory rather than an optimization
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+            # The sender skips the transfer as well, to match the receiver's behaviour described above
+            if self.rank != self.size - 1 and len(send_buf) != 0:
                 if cells_front > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_front}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
-                                    dest=self.rank + 1, tag=1)
+                self._send(send_buf, dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            total_cells_back = self._allgather(cells_back) + [0]
+            total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
-            if self.rank != 0:
+            send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
+            # Same reasoning as for the front cells applies
+            recv_shapes = self.local_shapes
+            if self.rank != 0 and len(send_buf) != 0:
                 if cells_back > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_back}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_back}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_back)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
-                                    dest=self.rank - 1, tag=0)
+                self._send(send_buf, dest=self.rank - 1, tag=0)
             if self.rank != self.size - 1:
-                ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
-                                          axis=self.axis)
+                recv_shape = list(recv_shapes[self.rank + 1])
+                recv_shape[self.axis] = total_cells_back[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+                                               axis=self.axis)
         return ghosted_array
 
     def __repr__(self):

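The new _send/_recv helpers above hide the MPI/NCCL switch behind one interface, but the NCCL path cannot allocate the receive buffer for the caller: the element count cannot be inferred, so the receiver must preallocate a buffer of the right shape, exactly as add_ghost_cells now does. A small illustrative sketch of that pattern (the helpers are private, so this is not a public API; x is assumed to be a CuPy-backed DistributedArray built as in the sketch near the top of this page, partitioned evenly along axis 0). Each rank receives the last two samples of its left neighbour, receiving before sending so the chain of blocking calls resolves from rank 0 outward:

import cupy as cp

n = 2
if x.rank > 0:
    # NCCL requires the receiver to provide a correctly shaped buffer up front
    recv_buf = cp.zeros((n,) + tuple(x.local_shape[1:]), dtype=x.local_array.dtype)
    halo = x._recv(recv_buf, source=x.rank - 1, tag=7)
if x.rank < x.size - 1:
    x._send(x.local_array[-n:], dest=x.rank + 1, tag=7)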
pylops_mpi/LinearOperator.py

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if self.Op:
             y = DistributedArray(global_shape=self.shape[0],
                                  base_comm=self.base_comm,
+                                 base_comm_nccl=x.base_comm_nccl,
                                  partition=x.partition,
                                  axis=x.axis,
                                  engine=x.engine,
@@ -123,6 +124,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         if self.Op:
             y = DistributedArray(global_shape=self.shape[1],
                                  base_comm=self.base_comm,
+                                 base_comm_nccl=x.base_comm_nccl,
                                  partition=x.partition,
                                  axis=x.axis,
                                  engine=x.engine,

pylops_mpi/basicoperators/BlockDiag.py

Lines changed: 2 additions & 2 deletions
@@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator],
     @reshaped(forward=True, stacking=True)
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
+        y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
@@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
+        y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):

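The LinearOperator and BlockDiag hunks above are all the same one-line pattern: whenever an operator builds its output DistributedArray, it must now forward base_comm and base_comm_nccl from the input, otherwise the result silently falls back to the default MPI-only communicator. A sketch of that pattern in a hypothetical user-defined operator (ScaleOp is invented for illustration, the MPILinearOperator constructor arguments shown are an assumption, and real operators in the repository additionally use the @reshaped decorator seen in the diffs):

from pylops_mpi import DistributedArray, MPILinearOperator

class ScaleOp(MPILinearOperator):
    """Multiply a DistributedArray by a scalar, preserving its communicators."""

    def __init__(self, scale, shape, dtype="float64"):
        self.scale = scale
        super().__init__(shape=shape, dtype=dtype)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        # forward both communicators so the output stays on the NCCL path
        y = DistributedArray(global_shape=self.shape[0],
                             base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                             local_shapes=x.local_shapes, axis=x.axis,
                             engine=x.engine, dtype=self.dtype)
        y[:] = self.scale * x.local_array
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        # the adjoint of a real scalar scaling is the same scaling
        return self._matvec(x)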
pylops_mpi/basicoperators/FirstDerivative.py

Lines changed: 10 additions & 10 deletions
@@ -129,19 +129,19 @@ def _register_multiplications(
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hmatvec(x)
 
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hrmatvec(x)
 
     @reshaped
     def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_back=1)
         y_forward = ghosted_x[1:] - ghosted_x[:-1]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         if self.rank == self.size - 1:
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1)
         y_backward = ghosted_x[1:] - ghosted_x[:-1]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=1)
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
         y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
 
@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2)
         y_centered = (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=4)

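All the FirstDerivative kernels follow the recipe visible in the hunks above: pad the local block with ghost cells from the neighbouring rank (now over NCCL when available), apply the stencil locally, and write the result into an output array that carries the same communicators. A condensed re-statement of the forward-difference case for a 1-D DistributedArray, written here for illustration rather than copied from the library (the real _matvec_forward also goes through the @reshaped decorator and handles multi-dimensional arrays):

import numpy as np
import cupy as cp
from pylops_mpi import DistributedArray

def forward_difference(x: DistributedArray) -> DistributedArray:
    ncp = cp if x.engine == "cupy" else np        # mirrors get_module(x.engine) in the diff
    y = DistributedArray(global_shape=x.global_shape,
                         base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                         local_shapes=x.local_shapes, axis=x.axis,
                         engine=x.engine, dtype=x.dtype)
    ghosted_x = x.add_ghost_cells(cells_back=1)   # one sample from rank + 1 (NCCL or MPI)
    y_forward = ghosted_x[1:] - ghosted_x[:-1]    # local stencil on the padded block
    if x.rank == x.size - 1:
        # the last rank has no right neighbour: its final output sample is zero
        y_forward = ncp.append(y_forward, ncp.zeros(1, dtype=y_forward.dtype))
    y[:] = y_forward
    return y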