Commit 5133c6b

Merge branch 'main' into feat-asarraymasked
2 parents e4604b1 + 9e25d8e

File tree: 11 files changed, +1123 −48 lines


pylops_mpi/DistributedArray.py

Lines changed: 66 additions & 15 deletions
@@ -14,7 +14,7 @@
 nccl_message = deps.nccl_import("the DistributedArray module")
 
 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -504,7 +504,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+        if deps.nccl_enabled and self.base_comm_nccl:
             return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
         else:
             if recv_buf is None:
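Both gather paths are needed because mpi4py exposes two APIs: the lowercase, pickle-based allgather that returns a Python list, and the uppercase, buffer-based Allgather that fills a pre-allocated array (which is also what the NCCL path requires). A standalone mpi4py sketch of the two variants, not pylops-mpi code, with illustrative values:

# Why _allgather branches on recv_buf: lowercase allgather returns a list,
# uppercase Allgather fills a pre-allocated buffer.
# Run with e.g. `mpiexec -n 2 python demo_allgather.py` (script name illustrative).
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

gathered_list = comm.allgather(rank)                          # list with one value per rank
recv_buf = np.zeros(comm.Get_size(), dtype=np.int64)          # buffer-based variant
comm.Allgather(np.array([rank], dtype=np.int64), recv_buf)
print(gathered_list, recv_buf)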
@@ -521,6 +521,37 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
         if recv_buf is None:
             return self.sub_comm.allgather(send_buf)
         self.sub_comm.Allgather(send_buf, recv_buf)
+
+    def _send(self, send_buf, dest, count=None, tag=None):
+        """ Send operation
+        """
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if count is None:
+                # assuming the whole array is sent
+                count = send_buf.size
+            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+        else:
+            self.base_comm.Send(send_buf, dest, tag)
+
+    def _recv(self, recv_buf=None, source=0, count=None, tag=None):
+        """ Receive operation
+        """
+        # NCCL must be called with recv_buf: the size cannot be inferred from the
+        # other arguments, so the buffer cannot be allocated dynamically
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if recv_buf is not None:
+                if count is None:
+                    # assuming the data fills the whole buffer
+                    count = recv_buf.size
+                nccl_recv(self.base_comm_nccl, recv_buf, source, count)
+                return recv_buf
+            else:
+                raise ValueError("Using recv with NCCL must also supply a receiver buffer")
+        else:
+            # MPI allows the receiver buffer to be optional
+            if recv_buf is None:
+                return self.base_comm.recv(source=source, tag=tag)
+            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
         return recv_buf
 
     def __neg__(self):
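The new `_send`/`_recv` helpers give DistributedArray a single point-to-point interface: on an NCCL communicator they call nccl_send/nccl_recv with an explicit element count and a pre-allocated receive buffer, otherwise they fall back to base_comm.Send/Recv. A minimal sketch of that MPI fallback pattern, assuming mpi4py and made-up sizes (the script name and shapes are illustrative, not pylops-mpi API):

# Rank 0 ships the tail of its local array to rank 1 into a pre-allocated buffer,
# mirroring the base_comm.Send / base_comm.Recv branch of _send/_recv.
# Run with `mpiexec -n 2 python demo_send_recv.py` (script name illustrative).
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    local = np.arange(10, dtype=np.float64)
    send_buf = np.take(local, np.arange(-2, 0))   # last two cells become ghost cells
    comm.Send(send_buf, dest=1, tag=1)            # blocking send, as in _send's MPI branch
elif rank == 1:
    recv_buf = np.zeros(2, dtype=np.float64)      # NCCL-style: receiver pre-allocates
    comm.Recv(recv_buf, source=0, tag=1)
    print("rank 1 received ghost cells:", recv_buf)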
@@ -560,6 +591,7 @@ def add(self, dist_array):
         self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
+                                    base_comm_nccl=self.base_comm_nccl,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
@@ -586,6 +618,7 @@ def multiply(self, dist_array):
 
         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
+                                        base_comm_nccl=self.base_comm_nccl,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
@@ -736,6 +769,8 @@ def ravel(self, order: Optional[str] = "C"):
         """
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
+                               base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                local_shapes=local_shapes,
                                mask=self.mask,
                                partition=self.partition,
@@ -764,41 +799,57 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
         -------
         ghosted_array : :obj:`numpy.ndarray`
             Ghosted Array
-
         """
         ghosted_array = self.local_array.copy()
+        ncp = get_module(self.engine)
         if cells_front is not None:
-            total_cells_front = self._allgather(cells_front) + [0]
+            # cells_front is a small array of int; explicitly use MPI
+            total_cells_front = self.base_comm.allgather(cells_front) + [0]
             # Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
+            send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
+            recv_shapes = self.local_shapes
             if self.rank != 0:
-                ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
-                                               axis=self.axis)
-            if self.rank != self.size - 1:
+                # From the receiver's perspective (rank), the recv buffer has the same shape as the
+                # sender's array (rank - 1) in every dimension except the one at axis=self.axis
+                recv_shape = list(recv_shapes[self.rank - 1])
+                recv_shape[self.axis] = total_cells_front[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                # Transfer of ghost cells can be skipped if len(recv_buf) == 0.
+                # Additionally, NCCL hangs if the buffer size is 0, so this optimization is effectively mandatory
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+            # The sender-side skip mirrors the receiver-side skip described above
+            if self.rank != self.size - 1 and len(send_buf) != 0:
                 if cells_front > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_front}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
-                                    dest=self.rank + 1, tag=1)
+                self._send(send_buf, dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            total_cells_back = self._allgather(cells_back) + [0]
+            total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
-            if self.rank != 0:
+            send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
+            # The same reasoning as for the front cells applies
+            recv_shapes = self.local_shapes
+            if self.rank != 0 and len(send_buf) != 0:
                 if cells_back > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_back}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_back}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_back)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
-                                    dest=self.rank - 1, tag=0)
+                self._send(send_buf, dest=self.rank - 1, tag=0)
             if self.rank != self.size - 1:
-                ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
-                                          axis=self.axis)
+                recv_shape = list(recv_shapes[self.rank + 1])
+                recv_shape[self.axis] = total_cells_back[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+                                               axis=self.axis)
         return ghosted_array
 
     def __repr__(self):
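The reworked add_ghost_cells pre-computes each receive buffer shape from the neighbour's local shape, because NCCL cannot infer message sizes, and skips zero-sized transfers, which would otherwise hang NCCL. A small single-process numpy sketch of that shape bookkeeping, with made-up shapes, ranks and ghost-cell counts:

# Suppose 3 ranks split a (9, 5) array along axis=0 and each requests 2 front ghost cells.
import numpy as np

axis = 0
cells_front = 2
local_shapes = [(3, 5), (3, 5), (3, 5)]            # one entry per rank
total_cells_front = [cells_front] * 3 + [0]        # what base_comm.allgather(cells_front) + [0] yields

rank = 1                                           # viewpoint of the middle rank
recv_shape = list(local_shapes[rank - 1])          # same shape as the sender (rank - 1) ...
recv_shape[axis] = total_cells_front[rank]         # ... except along the ghosted axis
recv_buf = np.zeros(recv_shape)                    # pre-allocated so NCCL knows the message size

print(recv_shape)          # [2, 5]
print(len(recv_buf) != 0)  # True -> the transfer is not skipped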

pylops_mpi/LinearOperator.py

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if self.Op:
             y = DistributedArray(global_shape=self.shape[0],
                                  base_comm=self.base_comm,
+                                 base_comm_nccl=x.base_comm_nccl,
                                  partition=x.partition,
                                  axis=x.axis,
                                  engine=x.engine,
@@ -123,6 +124,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         if self.Op:
             y = DistributedArray(global_shape=self.shape[1],
                                  base_comm=self.base_comm,
+                                 base_comm_nccl=x.base_comm_nccl,
                                  partition=x.partition,
                                  axis=x.axis,
                                  engine=x.engine,
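The same one-line change recurs in LinearOperator, BlockDiag and FirstDerivative below: every output DistributedArray now inherits the input's MPI and NCCL communicators. A hypothetical helper, not part of pylops-mpi, that captures the repeated pattern:

# Hypothetical helper that factors out the communicator propagation seen in these hunks:
# every output DistributedArray is created with the same base_comm/base_comm_nccl as the input x.
def _comm_kwargs(x):
    """Communicator-related kwargs to forward from an input DistributedArray."""
    return dict(base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)

# Usage sketch inside an operator's _matvec (other arguments as in the diff):
# y = DistributedArray(global_shape=self.shape[0], **_comm_kwargs(x),
#                      partition=x.partition, axis=x.axis, engine=x.engine, dtype=self.dtype)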

pylops_mpi/basicoperators/BlockDiag.py

Lines changed: 2 additions & 2 deletions
@@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator],
     @reshaped(forward=True, stacking=True)
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
+        y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
@@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
+        y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):

pylops_mpi/basicoperators/FirstDerivative.py

Lines changed: 10 additions & 10 deletions
@@ -129,19 +129,19 @@ def _register_multiplications(
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hmatvec(x)
 
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hrmatvec(x)
 
     @reshaped
     def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_back=1)
         y_forward = ghosted_x[1:] - ghosted_x[:-1]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         if self.rank == self.size - 1:
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1)
         y_backward = ghosted_x[1:] - ghosted_x[:-1]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=1)
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
         y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
 
@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2)
         y_centered = (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=4)
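Each stencil above computes local differences on a ghosted copy of the input; for example, _matvec_forward evaluates ghosted_x[1:] - ghosted_x[:-1] after add_ghost_cells(cells_back=1). A single-process numpy check of that idea with a made-up two-chunk split; the zero placed at the last sample only loosely mimics the boundary handling and is purely illustrative:

# With one ghost cell appended from the right neighbour, the local forward difference
# matches the corresponding slice of the global np.diff. Values and the 2-way split are made up.
import numpy as np

x = np.array([1., 4., 9., 16., 25., 36.])
chunks = [x[:3], x[3:]]                      # toy partition over two "ranks"

pieces = []
for rank, local in enumerate(chunks):
    if rank < len(chunks) - 1:
        # cells_back=1: append one ghost cell received from rank + 1
        ghosted = np.append(local, chunks[rank + 1][:1])
        pieces.append(ghosted[1:] - ghosted[:-1])
    else:
        # last rank has no right neighbour; pad the final entry with 0 in this sketch
        pieces.append(np.append(local[1:] - local[:-1], 0.))

y = np.concatenate(pieces)
assert np.allclose(y[:-1], np.diff(x))       # interior matches the global first difference
print(y)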
