Skip to content

Commit d7d07ab

Browse files
committed
explicitly pass x.base_comm to DistributedArray as suggested in PR
1 parent fb14c86 commit d7d07ab

File tree

6 files changed

+24
-22
lines changed

6 files changed

+24
-22
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,7 @@ def ravel(self, order: Optional[str] = "C"):
750750
"""
751751
local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
752752
arr = DistributedArray(global_shape=np.prod(self.global_shape),
753+
base_comm=self.base_comm,
753754
base_comm_nccl=self.base_comm_nccl,
754755
local_shapes=local_shapes,
755756
mask=self.mask,

pylops_mpi/basicoperators/BlockDiag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator],
121121
@reshaped(forward=True, stacking=True)
122122
def _matvec(self, x: DistributedArray) -> DistributedArray:
123123
ncp = get_module(x.engine)
124-
y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
124+
y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
125125
mask=self.mask, engine=x.engine, dtype=self.dtype)
126126
y1 = []
127127
for iop, oper in enumerate(self.ops):
@@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
133133
@reshaped(forward=False, stacking=True)
134134
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
135135
ncp = get_module(x.engine)
136-
y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
136+
y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
137137
mask=self.mask, engine=x.engine, dtype=self.dtype)
138138
y1 = []
139139
for iop, oper in enumerate(self.ops):

pylops_mpi/basicoperators/FirstDerivative.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,19 +129,19 @@ def _register_multiplications(
129129
def _matvec(self, x: DistributedArray) -> DistributedArray:
130130
# If Partition.BROADCAST, then convert to Partition.SCATTER
131131
if x.partition is Partition.BROADCAST:
132-
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
132+
x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
133133
return self._hmatvec(x)
134134

135135
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
136136
# If Partition.BROADCAST, then convert to Partition.SCATTER
137137
if x.partition is Partition.BROADCAST:
138-
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
138+
x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
139139
return self._hrmatvec(x)
140140

141141
@reshaped
142142
def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
143143
ncp = get_module(x.engine)
144-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
144+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
145145
axis=x.axis, engine=x.engine, dtype=self.dtype)
146146
ghosted_x = x.add_ghost_cells(cells_back=1)
147147
y_forward = ghosted_x[1:] - ghosted_x[:-1]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
153153
@reshaped
154154
def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
155155
ncp = get_module(x.engine)
156-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
156+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
157157
axis=x.axis, engine=x.engine, dtype=self.dtype)
158158
y[:] = 0
159159
if self.rank == self.size - 1:
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
171171
@reshaped
172172
def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
173173
ncp = get_module(x.engine)
174-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
174+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
175175
axis=x.axis, engine=x.engine, dtype=self.dtype)
176176
ghosted_x = x.add_ghost_cells(cells_front=1)
177177
y_backward = ghosted_x[1:] - ghosted_x[:-1]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
183183
@reshaped
184184
def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
185185
ncp = get_module(x.engine)
186-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
186+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
187187
axis=x.axis, engine=x.engine, dtype=self.dtype)
188188
y[:] = 0
189189
ghosted_x = x.add_ghost_cells(cells_back=1)
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
201201
@reshaped
202202
def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
203203
ncp = get_module(x.engine)
204-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
204+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
205205
axis=x.axis, engine=x.engine, dtype=self.dtype)
206206
ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
207207
y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
221221
@reshaped
222222
def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
223223
ncp = get_module(x.engine)
224-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
224+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
225225
axis=x.axis, engine=x.engine, dtype=self.dtype)
226226
y[:] = 0
227227

@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
249249
@reshaped
250250
def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
251251
ncp = get_module(x.engine)
252-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
252+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
253253
axis=x.axis, engine=x.engine, dtype=self.dtype)
254254
ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2)
255255
y_centered = (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
276276
@reshaped
277277
def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
278278
ncp = get_module(x.engine)
279-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
279+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
280280
axis=x.axis, engine=x.engine, dtype=self.dtype)
281281
y[:] = 0
282282
ghosted_x = x.add_ghost_cells(cells_back=4)

pylops_mpi/basicoperators/SecondDerivative.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,19 @@ def _register_multiplications(
112112
def _matvec(self, x: DistributedArray) -> DistributedArray:
113113
# If Partition.BROADCAST, then convert to Partition.SCATTER
114114
if x.partition is Partition.BROADCAST:
115-
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
115+
x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
116116
return self._hmatvec(x)
117117

118118
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
119119
# If Partition.BROADCAST, then convert to Partition.SCATTER
120120
if x.partition is Partition.BROADCAST:
121-
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
121+
x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
122122
return self._hrmatvec(x)
123123

124124
@reshaped
125125
def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
126126
ncp = get_module(x.engine)
127-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
127+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
128128
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
129129
ghosted_x = x.add_ghost_cells(cells_back=2)
130130
y_forward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
@@ -136,7 +136,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
136136
@reshaped
137137
def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
138138
ncp = get_module(x.engine)
139-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
139+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
140140
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
141141
y[:] = 0
142142
if self.rank == self.size - 1:
@@ -163,7 +163,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
163163
@reshaped
164164
def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
165165
ncp = get_module(x.engine)
166-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
166+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
167167
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
168168
ghosted_x = x.add_ghost_cells(cells_front=2)
169169
y_backward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
@@ -175,7 +175,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
175175
@reshaped
176176
def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
177177
ncp = get_module(x.engine)
178-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
178+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
179179
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
180180
y[:] = 0
181181
ghosted_x = x.add_ghost_cells(cells_back=2)
@@ -202,7 +202,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
202202
@reshaped
203203
def _matvec_centered(self, x: DistributedArray) -> DistributedArray:
204204
ncp = get_module(x.engine)
205-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
205+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
206206
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
207207
ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
208208
y_centered = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
@@ -222,7 +222,7 @@ def _matvec_centered(self, x: DistributedArray) -> DistributedArray:
222222
@reshaped
223223
def _rmatvec_centered(self, x: DistributedArray) -> DistributedArray:
224224
ncp = get_module(x.engine)
225-
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
225+
y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
226226
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
227227
y[:] = 0
228228
ghosted_x = x.add_ghost_cells(cells_back=2)

pylops_mpi/basicoperators/VStack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
130130
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
131131
f"Got {x.partition} instead...")
132132
# the output y should use NCCL if the operand x uses it
133-
y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
133+
y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
134134
engine=x.engine, dtype=self.dtype)
135135
y1 = []
136136
for iop, oper in enumerate(self.ops):
@@ -141,7 +141,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
141141
@reshaped(forward=False, stacking=True)
142142
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
143143
ncp = get_module(x.engine)
144-
y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
144+
y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
145145
engine=x.engine, dtype=self.dtype)
146146
y1 = []
147147
for iop, oper in enumerate(self.ops):

pylops_mpi/utils/decorators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def wrapper(self, x: DistributedArray):
5454
local_shapes = None
5555
global_shape = getattr(self, "dims")
5656
arr = DistributedArray(global_shape=global_shape,
57+
base_comm=x.base_comm,
5758
base_comm_nccl=x.base_comm_nccl,
5859
local_shapes=local_shapes, axis=0,
5960
engine=x.engine, dtype=x.dtype)

0 commit comments

Comments
 (0)