Skip to content

Commit 7363431

Browse files
committed
cleanedup adjoint impl
1 parent 8142d44 commit 7363431

File tree

1 file changed

+25
-30
lines changed

1 file changed

+25
-30
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -246,22 +246,22 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
246246
ncp = get_module(x.engine)
247247
if x.partition != Partition.SCATTER:
248248
raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
249-
local_shape = (self.N // self._P_prime) * ( self.M * self._P_prime // self.size)
250-
y = DistributedArray(global_shape=((self.N // self._P_prime) * self.M * self._P_prime),
249+
local_shape = ((self.N * self.M) // self.size)
250+
y = DistributedArray(global_shape=(self.N * self.M),
251251
mask=x.mask,
252-
local_shapes=[ local_shape for _ in range(self.size)],
252+
local_shapes=[local_shape] * self.size,
253253
partition=Partition.SCATTER,
254254
dtype=self.dtype)
255255

256256
x = x.local_array.reshape((self.A.shape[1], -1))
257-
c_local = np.zeros((self.A.shape[0], x.shape[1]))
257+
Y_local = np.zeros((self.A.shape[0], x.shape[1]))
258258
for k in range(self._P_prime):
259259
Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A)
260260
Xtemp = x.copy() if self._row_id == k else np.empty_like(x)
261261
self._row_comm.Bcast(Atemp, root=k)
262262
self._col_comm.Bcast(Xtemp, root=k)
263-
c_local += ncp.dot(Atemp, Xtemp)
264-
y[:] = c_local.flatten()
263+
Y_local += ncp.dot(Atemp, Xtemp)
264+
y[:] = Y_local.flatten()
265265
return y
266266

267267

@@ -270,38 +270,33 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
270270
if x.partition != Partition.SCATTER:
271271
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
272272

273-
local_shape = (self.K // self._P_prime) * (self.M * self._P_prime // self.size)
273+
local_shape = ((self.K * self.M ) // self.size)
274274
y = DistributedArray(
275-
global_shape=((self.K // self._P_prime) * self.M * self._P_prime),
275+
global_shape=(self.K * self.M),
276276
mask=x.mask,
277-
local_shapes=[local_shape for _ in range(self.size)],
277+
local_shapes=[local_shape] * self.size,
278278
partition=Partition.SCATTER,
279279
dtype=self.dtype,
280280
)
281281
x_reshaped = x.local_array.reshape((self.A.shape[0], -1))
282282
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
283-
c_local = np.zeros((self.A.shape[1], x_reshaped.shape[1]))
284-
P = self._P_prime
283+
Y_local = np.zeros((self.A.shape[1], x_reshaped.shape[1]))
285284

286-
for k in range(P):
287-
temps = {}
285+
for k in range(self._P_prime):
288286
requests = []
289-
for buf, owner, base, name in (
290-
(A_local, self._row_id, 100, 'A'),
291-
(x_reshaped, self._col_id, 200, 'B'),
292-
):
293-
tmp = np.empty_like(buf)
294-
temps[name] = tmp
295-
src, tag = k * P + owner, (base + k) * 1000 + self.rank
296-
requests.append(self.base_comm.Irecv(tmp, source=src, tag=tag))
297-
298-
if self.rank // P == k:
299-
fixed = self.rank % P
300-
for moving in range(P):
301-
dest = (fixed * P + moving) if name == 'A' else moving * P + fixed
302-
tag = (base + k) * 1000 + dest
303-
requests.append(self.base_comm.Isend(buf, dest=dest, tag=tag))
287+
ATtemp = np.empty_like(A_local)
288+
srcA = k * self._P_prime + self._row_id
289+
tagA = (100 + k) * 1000 + self.rank
290+
requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA))
291+
if self._row_id == k:
292+
fixed_col = self._col_id
293+
for moving_col in range(self._P_prime):
294+
destA = fixed_col * self._P_prime + moving_col
295+
tagA = (100 + k) * 1000 + destA
296+
requests.append(self.base_comm.Isend(A_local, dest=destA,tag=tagA))
297+
Xtemp = x_reshaped.copy() if self._row_id == k else np.empty_like(x_reshaped)
298+
requests.append(self._col_comm.Ibcast(Xtemp, root=k))
304299
MPI.Request.Waitall(requests)
305-
c_local += ncp.dot(temps['A'], temps['B'])
306-
y[:] = c_local.flatten()
300+
Y_local += ncp.dot(ATtemp, Xtemp)
301+
y[:] = Y_local.flatten()
307302
return y

0 commit comments

Comments
 (0)