|
3 | 3 | from mpi4py import MPI |
4 | 4 |
|
5 | 5 | import pylops_mpi |
6 | | -from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult |
| 6 | +from pylops_mpi.basicoperators.MatrixMult import (local_block_split,
| 7 | + block_gather, |
| 8 | + MPISummaMatrixMult) |
7 | 9 |
|
8 | 10 | comm = MPI.COMM_WORLD |
9 | 11 | rank = comm.Get_rank() |
|
23 | 25 | A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) |
24 | 26 | B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape) |
25 | 27 |
|
26 | | -i, j = divmod(rank, p_prime) |
27 | | -A_local, (N_new, K_new) = MPIMatrixMult.block_distribute(A_data, i, j,comm) |
28 | | -B_local, (K_new, M_new) = MPIMatrixMult.block_distribute(B_data, i, j,comm) |
| 28 | +A_slice = local_block_split(A_shape, rank, comm)
| 29 | +B_slice = local_block_split(B_shape, rank, comm)
| 30 | +A_local = A_data[A_slice] |
| 31 | +B_local = B_data[B_slice] |
| 32 | +# A_local/B_local now hold this rank's tiles of the global operands
| 33 | +# on the p' x p' process grid
29 | 34 |
|
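For readers without the library source at hand, here is a minimal single-process sketch of the block split used above: each of P = p' * p' ranks owns one tile of the global matrix, with the p' x p' grid implied by the removed divmod(rank, p_prime) line. The helper name local_block_split_sketch and the ceil-division blocking are illustrative assumptions, not the library's confirmed implementation.

import math
import numpy as np

def local_block_split_sketch(global_shape, rank, n_procs):
    # Assumed behaviour: map rank -> (i, j) on a p' x p' grid, then return
    # the 2D slice of the global array owned by that rank.
    p_prime = math.isqrt(n_procs)
    i, j = divmod(rank, p_prime)
    br = math.ceil(global_shape[0] / p_prime)  # block height
    bc = math.ceil(global_shape[1] / p_prime)  # block width
    return (slice(i * br, min((i + 1) * br, global_shape[0])),
            slice(j * bc, min((j + 1) * bc, global_shape[1])))

# 2 x 2 grid over a 4 x 4 matrix: rank 0 owns the top-left tile, and so on.
A = np.arange(16).reshape(4, 4)
print(A[local_block_split_sketch(A.shape, 0, 4)])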
30 | 35 | B_dist = pylops_mpi.DistributedArray(global_shape=(K * M), |
31 | 36 | local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]), |
32 | 37 | base_comm=comm, |
33 | 38 | partition=pylops_mpi.Partition.SCATTER) |
34 | 39 | B_dist.local_array[:] = B_local.flatten() |
35 | 40 |
|
36 | | -Aop = MPIMatrixMult(A_local, M, base_comm=comm) |
| 41 | +Aop = MPISummaMatrixMult(A_local, M, base_comm=comm) |
37 | 42 | C_dist = Aop @ B_dist |
38 | 43 | Z_dist = Aop.H @ C_dist |
39 | 44 |
|
40 | | -C = MPIMatrixMult.block_gather(C_dist, (N,M), (N,M), comm) |
41 | | -Z = MPIMatrixMult.block_gather(Z_dist, (K,M), (K,M), comm) |
| 45 | +C = block_gather(C_dist, (N, M), (N, M), comm)
| 46 | +Z = block_gather(Z_dist, (K, M), (K, M), comm)
42 | 47 | if rank == 0:
43 | 48 | C_correct = np.allclose(A_data @ B_data, C) |
44 | 49 | print("C expected: ", C_correct) |
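Since the operands are partitioned over a p' x p' grid, the example presumably has to be launched with a square number of ranks, e.g. mpirun -n 4 python <this script>. For symmetry with the split sketch above, here is a single-process sketch of the gather step, pasting each rank's tile back into the full result; only the block_gather call sites come from this diff, and this stand-alone version is an illustrative assumption.

import math
import numpy as np

def block_gather_sketch(tiles, global_shape, n_procs):
    # Assumed behaviour: inverse of the block split, pasting tile r back into
    # its (i, j) position on the p' x p' grid.
    p_prime = math.isqrt(n_procs)
    br = math.ceil(global_shape[0] / p_prime)
    bc = math.ceil(global_shape[1] / p_prime)
    out = np.zeros(global_shape, dtype=tiles[0].dtype)
    for r, tile in enumerate(tiles):
        i, j = divmod(r, p_prime)
        out[i * br:i * br + tile.shape[0], j * bc:j * bc + tile.shape[1]] = tile
    return out

# Round-trip check: split a 4 x 4 matrix into four tiles (row-major rank
# order), gather them back, and compare with the original.
A = np.arange(16.0).reshape(4, 4)
tiles = [A[:2, :2], A[:2, 2:], A[2:, :2], A[2:, 2:]]
assert np.allclose(block_gather_sketch(tiles, A.shape, 4), A)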
|