Skip to content

Commit 18db078

Browse files
committed
removed now useless bcast and fixed mask in test
1 parent bd9ad37 commit 18db078

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
182182

183183
my_own_cols = self._rank_col_lens[self.rank]
184184
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
185-
x_arr = x_arr.astype(self.dtype)
186-
187-
X_local = self._layer_comm.bcast(x_arr if self._group_id == self._layer_id else None, root=self._layer_id)
185+
X_local = x_arr.astype(self.dtype)
188186
Y_local = ncp.vstack(
189187
self._layer_comm.allgather(
190188
ncp.matmul(self.A, X_local)

tests/test_matrixmult.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
8080
local_shapes=[K * cl_b for cl_b in all_local_col_len],
8181
partition=Partition.SCATTER,
8282
base_comm=comm,
83-
mask=[i // p_prime for i in range(size)],
83+
mask=[i % p_prime for i in range(size)],
8484
dtype=dtype
8585
)
8686

0 commit comments

Comments
 (0)