Skip to content

Commit 58d3ceb

Browse files
committed
Cleanup
1 parent dc00226 commit 58d3ceb

File tree

2 files changed

+381
-169
lines changed

2 files changed

+381
-169
lines changed
Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from mpi4py import MPI
44

55
import pylops_mpi
6-
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
6+
from pylops_mpi.basicoperators.MatrixMult import (local_block_spit,
7+
block_gather,
8+
MPISummaMatrixMult)
79

810
comm = MPI.COMM_WORLD
911
rank = comm.Get_rank()
@@ -23,22 +25,25 @@
2325
A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
2426
B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape)
2527

26-
i, j = divmod(rank, p_prime)
27-
A_local, (N_new, K_new) = MPIMatrixMult.block_distribute(A_data, i, j,comm)
28-
B_local, (K_new, M_new) = MPIMatrixMult.block_distribute(B_data, i, j,comm)
28+
A_slice = local_block_spit(A_shape, rank, comm)
29+
B_slice = local_block_spit(B_shape, rank, comm)
30+
A_local = A_data[A_slice]
31+
B_local = B_data[B_slice]
32+
# A_local, (N_new, K_new) = block_distribute(A_data,rank, comm)
33+
# B_local, (K_new, M_new) = block_distribute(B_data,rank, comm)
2934

3035
B_dist = pylops_mpi.DistributedArray(global_shape=(K * M),
3136
local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]),
3237
base_comm=comm,
3338
partition=pylops_mpi.Partition.SCATTER)
3439
B_dist.local_array[:] = B_local.flatten()
3540

36-
Aop = MPIMatrixMult(A_local, M, base_comm=comm)
41+
Aop = MPISummaMatrixMult(A_local, M, base_comm=comm)
3742
C_dist = Aop @ B_dist
3843
Z_dist = Aop.H @ C_dist
3944

40-
C = MPIMatrixMult.block_gather(C_dist, (N,M), (N,M), comm)
41-
Z = MPIMatrixMult.block_gather(Z_dist, (K,M), (K,M), comm)
45+
C = block_gather(C_dist, (N,M), (N,M), comm)
46+
Z = block_gather(Z_dist, (K,M), (K,M), comm)
4247
if rank == 0 :
4348
C_correct = np.allclose(A_data @ B_data, C)
4449
print("C expected: ", C_correct)

0 commit comments

Comments
 (0)