Skip to content

Commit bd9ad37

Browse files
committed
minor: fix mistake in plot_matrixmult
1 parent a110ff8 commit bd9ad37

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/plot_matrixmult.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@
152152
Aop = MPIMatrixMult(A_p, N, dtype="float32")
153153

154154
col_lens = comm.allgather(my_own_cols)
155-
x = DistributedArray(global_shape=K * N,
155+
total_cols = np.sum(col_lens)
156+
x = DistributedArray(global_shape=K * total_cols,
156157
local_shapes=[K * col_len for col_len in col_lens],
157158
partition=Partition.SCATTER,
158159
mask=[i % p_prime for i in range(comm.Get_size())],

0 commit comments

Comments
 (0)