 from mpi4py import MPI

 from pylops_mpi import DistributedArray, Partition
-from pylops_mpi.basicoperators.MatrixMult import MPISUMMAMatrixMult
+from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

 np.random.seed(42)
 P_prime = int(math.ceil(math.sqrt(n_procs)))
 C = int(math.ceil(n_procs / P_prime))

-if P_prime * C < n_procs:
+if (P_prime * C) != n_procs:
     print("No. of procs has to be a square number")
     exit(-1)
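The old `<` guard let mismatched rank counts through: with 7 ranks, `P_prime * C = 9` is not less than 7, so the check passed and the grid silently over-allocated. The exact `!=` test catches that, although a non-square count such as 6 (where `3 * 2 == 6`) still slips past the "square number" message; `P_prime * P_prime == n_procs` would be the literal check. A standalone sketch of how the guards differ (`n_procs` comes from `comm.Get_size()` in the elided setup lines):

```python
import math

# Compare the old (<) and new (!=) guards for a few rank counts.
for n_procs in (4, 6, 7, 9, 16):
    P_prime = int(math.ceil(math.sqrt(n_procs)))
    C = int(math.ceil(n_procs / P_prime))
    old_rejects = P_prime * C < n_procs        # old guard
    new_rejects = (P_prime * C) != n_procs     # new guard
    is_square = P_prime * P_prime == n_procs   # what the message demands
    print(f"n_procs={n_procs:2d} old_rejects={old_rejects} "
          f"new_rejects={new_rejects} square={is_square}")
# n_procs=7: the old guard accepts (9 is not < 7) but the new guard rejects;
# n_procs=6: both accept (3 * 2 == 6) even though 6 is not a square.
```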

 # matrix dims
 M = 32
-K = 32
-N = 35
+K = 35
+N = 37

-blk_rows = int(math.ceil(M / P_prime))
-blk_cols = int(math.ceil(N / P_prime))
+A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
+B = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
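K and N change from 32 and 35 to 35 and 37, so all three dimensions are distinct (which helps catch shape and transposition mistakes) and not divisible by the grid size (which exercises the `min()` edge handling below). A and B are also now generated identically on every rank from the shared seed, rather than materialized only on rank 0. A sketch of the uneven blocks this produces, assuming a 2 x 2 grid:

```python
import math

M, N, P_prime = 32, 37, 2
blk_rows = int(math.ceil(M / P_prime))   # 16: rows split evenly
blk_cols = int(math.ceil(N / P_prime))   # 19: columns cannot split evenly
for b in range(P_prime):
    rs, re = b * blk_rows, min(M, (b + 1) * blk_rows)
    cs, ce = b * blk_cols, min(N, (b + 1) * blk_cols)
    print(f"block {b}: rows {rs}:{re} ({re - rs}), cols {cs}:{ce} ({ce - cs})")
# block 0: rows 0:16 (16), cols 0:19 (19)
# block 1: rows 16:32 (16), cols 19:37 (18)  <- min() trims the last block
```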

 my_group = rank % P_prime
 my_layer = rank // P_prime

 layer_comm = comm.Split(color=my_layer, key=my_group)  # all procs in same layer
 group_comm = comm.Split(color=my_group, key=my_layer)  # all procs in same group
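`comm.Split` partitions ranks by `color` and orders each subcommunicator by `key`, so `layer_comm` connects all ranks of one layer (one row of the grid) and `group_comm` all ranks of one group (one column). A minimal standalone sketch of the rank-to-grid mapping for four ranks:

```python
# Rank layout on the P' x P' grid for n_procs = 4 (no MPI needed to see it).
P_prime = 2
for rank in range(4):
    my_group = rank % P_prime    # position within a layer
    my_layer = rank // P_prime   # which layer the rank belongs to
    print(f"rank {rank} -> group {my_group}, layer {my_layer}")
# rank 0 -> (0, 0), rank 1 -> (1, 0), rank 2 -> (0, 1), rank 3 -> (1, 1)
```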

-# Each rank will end up with:
-#   A_p: shape (my_own_rows, K)
-#   B_p: shape (K, my_own_cols)
-# where
+
+# Each rank will end up with:
+#      - :math:`A_{p} \in \mathbb{R}^{\text{my\_own\_rows}\times K}`
+#      - :math:`B_{p} \in \mathbb{R}^{K\times \text{my\_own\_cols}}`
+#    where
+blk_rows = int(math.ceil(M / P_prime))
 row_start = my_group * blk_rows
 row_end = min(M, row_start + blk_rows)
 my_own_rows = row_end - row_start

-col_start = my_group * blk_cols  # note: same my_group index on cols
+blk_cols = int(math.ceil(N / P_prime))
+col_start = my_layer * blk_cols
 col_end = min(N, col_start + blk_cols)
 my_own_cols = col_end - col_start
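This hunk fixes the actual bug: B's columns were sliced by `my_group` (the deleted inline note even flagged it), while the old verification block further down sliced the expected result by `my_layer`, so on a 2 x 2 grid ranks 1 and 2 were checked against columns they never received. Column offsets are now tied to `my_layer` throughout, and `blk_rows`/`blk_cols` move next to their first use. The two mappings side by side:

```python
import math

N, P_prime = 37, 2
blk_cols = int(math.ceil(N / P_prime))
for rank in range(4):
    my_group, my_layer = rank % P_prime, rank // P_prime
    old_cs = my_group * blk_cols   # old: tied to the group (row) index
    new_cs = my_layer * blk_cols   # new: tied to the layer, as verified below
    print(f"rank {rank}: old col_start={old_cs}, new col_start={new_cs}")
# rank 1: old gives cols 19:, new gives cols 0:   (mismatch under old code)
# rank 2: old gives cols 0:,  new gives cols 19:  (mismatch under old code)
```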

-# ======================= BROADCASTING THE SLICES =======================
-if rank == 0:
-    A = np.arange(M * K, dtype=np.float32).reshape(M, K)
-    B = np.arange(K * N, dtype=np.float32).reshape(K, N)
-    for dest in range(n_procs):
-        pg = dest % P_prime
-        rs = pg * blk_rows;
-        re = min(M, rs + blk_rows)
-        cs = pg * blk_cols;
-        ce = min(N, cs + blk_cols)
-        a_block, b_block = A[rs:re, :], B[:, cs:ce]
-        if dest == 0:
-            A_p, B_p = a_block, b_block
-        else:
-            comm.Send(a_block, dest=dest, tag=100 + dest)
-            comm.Send(b_block, dest=dest, tag=200 + dest)
-else:
-    A_p = np.empty((my_own_rows, K), dtype=np.float32)
-    B_p = np.empty((K, my_own_cols), dtype=np.float32)
-    comm.Recv(A_p, source=0, tag=100 + rank)
-    comm.Recv(B_p, source=0, tag=200 + rank)

-comm.Barrier()
+rs = (rank % P_prime) * blk_rows
+re = min(M, rs + blk_rows)

-Aop = MPISUMMAMatrixMult(A_p, N)
+cs = (rank // P_prime) * blk_cols
+ce = min(N, cs + blk_cols)
+A_p, B_p = A[rs:re, :].copy(), B[:, cs:ce].copy()
+
+Aop = MPIMatrixMult(A_p, N, dtype="float32")
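With A and B now built identically everywhere, the root-scatter ceremony above (tagged `Send`/`Recv` pairs plus a barrier) collapses into local slicing: `rank % P_prime` is `my_group` and `rank // P_prime` is `my_layer`, so `rs:re` and `cs:ce` reproduce `row_start:row_end` and `col_start:col_end`. The `.copy()` keeps each block C-contiguous for buffer-based MPI calls inside the operator. Each rank's `A_p @ B_p` is one disjoint tile of `C = A @ B`; the operator then combines tiles (presumably gathering row blocks across `layer_comm`) so that `y` holds all M rows of each layer's column block, which is what the verification below slices out. A toy-size sketch of the tiling:

```python
import math
import numpy as np

# Toy-size check (hypothetical M, K, N) that the per-rank slices tile C = A @ B.
M, K, N, P_prime = 6, 5, 7, 2
A = np.arange(M * K, dtype=np.float32).reshape(M, K)
B = np.arange(K * N, dtype=np.float32).reshape(K, N)
blk_rows = int(math.ceil(M / P_prime))
blk_cols = int(math.ceil(N / P_prime))

C_ref = A @ B
C_tiled = np.zeros_like(C_ref)
for rank in range(P_prime * P_prime):
    rs = (rank % P_prime) * blk_rows
    re = min(M, rs + blk_rows)
    cs = (rank // P_prime) * blk_cols
    ce = min(N, cs + blk_cols)
    A_p, B_p = A[rs:re, :].copy(), B[:, cs:ce].copy()
    C_tiled[rs:re, cs:ce] = A_p @ B_p   # the tile this rank is responsible for
assert np.allclose(C_tiled, C_ref)      # the four tiles cover C exactly once
```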
 col_lens = comm.allgather(my_own_cols)
 total_cols = np.sum(col_lens)
 x = DistributedArray(global_shape=K * total_cols,
                      local_shapes=[K * col_len for col_len in col_lens],
                      partition=Partition.SCATTER,
                      mask=[i % P_prime for i in range(comm.Get_size())],
-                     dtype=np.float32)
+                     base_comm=comm,
+                     dtype="float32")
 x[:] = B_p.flatten()
 y = Aop @ x
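`x` stacks the per-rank column blocks of B as one distributed 1D vector: `local_shapes` records every rank's `K * my_own_cols` so the global layout stays consistent, and `base_comm` is now passed explicitly. The `mask` value `i % P_prime` equals `my_group`, and each mask subset (for example ranks 0 and 2 on four ranks) holds one complete copy of B's columns between them; I read it as letting pylops-mpi account for that replication in reductions, but the exact semantics are worth checking against the DistributedArray docs. Locally the packing is plain row-major flattening:

```python
import numpy as np

# Row-major packing round trip: what `x[:] = B_p.flatten()` stores locally
# and how a (K, my_own_cols) block is recovered from it (toy sizes).
K, my_own_cols = 5, 4
B_p = np.arange(K * my_own_cols, dtype=np.float32).reshape(K, my_own_cols)

x_local = B_p.flatten()                   # 1D buffer, length K * my_own_cols
assert np.array_equal(x_local.reshape(K, my_own_cols), B_p)
# y.local_array comes back the same way: a flattened (M, my_own_cols) block,
# which is why the verification below flattens y_loc[:, col_start:col_end].
```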

 # ======================= VERIFICATION =======================
-A = np.arange(M * K).reshape(M, K).astype(np.float32)
-B = np.arange(K * N).reshape(K, N).astype(np.float32)
-C_true = A @ B
-Z_true = (A.T.dot(C_true.conj())).conj()
+y_loc = A @ B
+xadj_loc = (A.T.dot(y_loc.conj())).conj()

-col_start = my_layer * blk_cols  # note: same my_group index on cols
-col_end = min(N, col_start + blk_cols)
-my_own_cols = col_end - col_start
-expected_y = C_true[:, col_start:col_end].flatten()

-xadj = Aop.H @ y
+expected_y_loc = y_loc[:, col_start:col_end].flatten().astype(np.float32)
+expected_xadj_loc = xadj_loc[:, col_start:col_end].flatten().astype(np.float32)

-if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14):
+xadj = Aop.H @ y
+if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6):
     print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
-    print(f'{rank} local: {y.local_array}, expected: {C_true[:, col_start:col_end]}')
+    print(f'{rank} local: {y.local_array}, expected: {y_loc[:, col_start:col_end]}')
 else:
     print(f"RANK {rank}: FORWARD VERIFICATION PASSED")
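The tolerance swap is the right call for float32: the old pair `atol=1e-6, rtol=1e-14` is effectively an absolute test (1e-14 sits seven orders of magnitude below single-precision eps, about 1.2e-7), so entries of even moderate magnitude fail on ordinary rounding noise, while `rtol=1e-6` scales with the data:

```python
import numpy as np

eps = np.finfo(np.float32).eps                   # ~1.19e-07
a = np.float32(1.0e4)
b = a * (np.float32(1) + 8 * eps)                # a few ulps of rounding error
print(np.allclose(a, b, atol=1e-6, rtol=1e-14))  # False: old, absolute-ish test
print(np.allclose(a, b, rtol=1e-6))              # True: new, scales with |b|
```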

-expected_z = Z_true[:, col_start:col_end].flatten()
-if not np.allclose(xadj.local_array, expected_z, atol=1e-6, rtol=1e-14):
+if not np.allclose(xadj.local_array, expected_xadj_loc, rtol=1e-6):
     print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
-    print(f'{rank} local: {xadj.local_array}, expected: {Z_true[:, col_start:col_end]}')
+    print(f'{rank} local: {xadj.local_array}, expected: {xadj_loc[:, col_start:col_end]}')
 else:
     print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
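Comparing against the dense `A @ B` on every rank works here because the operands are replicated, but the self-contained way to validate a forward/adjoint pair is the dot test, checking that <A u, v> equals <u, A^H v> for random u and v, the identity that `pylops.utils.dottest` automates for serial operators. A plain NumPy sketch of the identity being checked:

```python
import numpy as np

rng = np.random.default_rng(42)
A = rng.standard_normal((8, 5)).astype(np.float32)
u = rng.standard_normal(5).astype(np.float32)   # random model-space vector
v = rng.standard_normal(8).astype(np.float32)   # random data-space vector

lhs = np.dot(A @ u, v)            # <A u, v>
rhs = np.dot(u, A.conj().T @ v)   # <u, A^H v>
assert np.isclose(lhs, rhs, rtol=1e-5)
```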