Skip to content

Commit fb14c86

Browse files
committed
NCCL support for SecondDerivative and test_derivative_nccl for first- and second-order derivatives
1 parent ae7190c commit fb14c86

File tree

3 files changed

+700
-15
lines changed

3 files changed

+700
-15
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def _allgather(self, send_buf, recv_buf=None):
506506
def _send(self, send_buf, dest, count=None, tag=None):
507507
""" Send operation
508508
"""
509-
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
509+
if deps.nccl_enabled and self.base_comm_nccl:
510510
if count is None:
511511
# assuming sending the whole array
512512
count = send_buf.size
@@ -519,7 +519,7 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
519519
"""
520520
# NCCL must be called with recv_buf. Size cannot be inferred from
521521
# other arguments and thus cannot be dynamically allocated
522-
if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None:
522+
if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None:
523523
if recv_buf is not None:
524524
if count is None:
525525
# assuming data will take a space of the whole buffer
@@ -572,6 +572,7 @@ def add(self, dist_array):
572572
self._check_mask(dist_array)
573573
SumArray = DistributedArray(global_shape=self.global_shape,
574574
base_comm=self.base_comm,
575+
base_comm_nccl=self.base_comm_nccl,
575576
dtype=self.dtype,
576577
partition=self.partition,
577578
local_shapes=self.local_shapes,
@@ -598,6 +599,7 @@ def multiply(self, dist_array):
598599

599600
ProductArray = DistributedArray(global_shape=self.global_shape,
600601
base_comm=self.base_comm,
602+
base_comm_nccl=self.base_comm_nccl,
601603
dtype=self.dtype,
602604
partition=self.partition,
603605
local_shapes=self.local_shapes,
@@ -748,6 +750,7 @@ def ravel(self, order: Optional[str] = "C"):
748750
"""
749751
local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
750752
arr = DistributedArray(global_shape=np.prod(self.global_shape),
753+
base_comm_nccl=self.base_comm_nccl,
751754
local_shapes=local_shapes,
752755
mask=self.mask,
753756
partition=self.partition,

pylops_mpi/basicoperators/SecondDerivative.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,20 @@ def _register_multiplications(
112112
def _matvec(self, x: DistributedArray) -> DistributedArray:
113113
# If Partition.BROADCAST, then convert to Partition.SCATTER
114114
if x.partition is Partition.BROADCAST:
115-
x = DistributedArray.to_dist(x=x.local_array)
115+
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
116116
return self._hmatvec(x)
117117

118118
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
119119
# If Partition.BROADCAST, then convert to Partition.SCATTER
120120
if x.partition is Partition.BROADCAST:
121-
x = DistributedArray.to_dist(x=x.local_array)
121+
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
122122
return self._hrmatvec(x)
123123

124124
@reshaped
125125
def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
126126
ncp = get_module(x.engine)
127-
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
128-
axis=x.axis, engine=x.engine, dtype=self.dtype)
127+
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
128+
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
129129
ghosted_x = x.add_ghost_cells(cells_back=2)
130130
y_forward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
131131
if self.rank == self.size - 1:
@@ -136,7 +136,8 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
136136
@reshaped
137137
def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
138138
ncp = get_module(x.engine)
139-
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype)
139+
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
140+
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
140141
y[:] = 0
141142
if self.rank == self.size - 1:
142143
y[:-2] += x[:-2]
@@ -162,8 +163,8 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
162163
@reshaped
163164
def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
164165
ncp = get_module(x.engine)
165-
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
166-
axis=x.axis, engine=x.engine, dtype=self.dtype)
166+
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
167+
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
167168
ghosted_x = x.add_ghost_cells(cells_front=2)
168169
y_backward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
169170
if self.rank == 0:
@@ -174,8 +175,8 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
174175
@reshaped
175176
def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
176177
ncp = get_module(x.engine)
177-
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
178-
axis=x.axis, engine=x.engine, dtype=self.dtype)
178+
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
179+
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
179180
y[:] = 0
180181
ghosted_x = x.add_ghost_cells(cells_back=2)
181182
y_backward = ghosted_x[2:]
@@ -201,8 +202,8 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
201202
@reshaped
202203
def _matvec_centered(self, x: DistributedArray) -> DistributedArray:
203204
ncp = get_module(x.engine)
204-
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
205-
axis=x.axis, engine=x.engine, dtype=self.dtype)
205+
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
206+
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
206207
ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
207208
y_centered = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
208209
if self.rank == 0:
@@ -221,8 +222,8 @@ def _matvec_centered(self, x: DistributedArray) -> DistributedArray:
221222
@reshaped
222223
def _rmatvec_centered(self, x: DistributedArray) -> DistributedArray:
223224
ncp = get_module(x.engine)
224-
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
225-
axis=x.axis, engine=x.engine, dtype=self.dtype)
225+
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
226+
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
226227
y[:] = 0
227228
ghosted_x = x.add_ghost_cells(cells_back=2)
228229
y_centered = ghosted_x[1:-1]

0 commit comments

Comments (0)