Skip to content

Commit 740030d

Browse files
committed
Minor cosmetic changes
1 parent 22cde7b commit 740030d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,13 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
104104
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
105105
m, b = A_local.shape
106106
pad = (-m) % self._P_prime
107-
r = (m + pad) // self._P_prime
108-
A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
109-
A_batch = A_pad.reshape(self._P_prime, r, b)
107+
A_pad = A_local if pad <= 0 else np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
108+
batch_sz = (m + pad) // self._P_prime
109+
A_batch = A_pad.reshape(self._P_prime, batch_sz, b)
110110

111111
Y_batch = ncp.matmul(A_batch, X_tile)
112-
Y_pad = Y_batch.reshape(r * self._P_prime, -1)
113-
y_local = Y_pad[:m, :]
112+
Y_pad = Y_batch.reshape(batch_sz * self._P_prime, -1)
113+
y_local = Y_pad[:A_local.shape[0], :]
114114
y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
115115
y[:] = y_layer.flatten()
116116
return y

0 commit comments

Comments
 (0)