|
15 | 15 |
|
16 | 16 | np.random.seed(42) |
17 | 17 |
|
18 | | -comm = MPI.COMM_WORLD |
19 | | -rank = comm.Get_rank() |
20 | | -nProcs = comm.Get_size() |
| 18 | +comm = MPI.COMM_WORLD |
| 19 | +rank = comm.Get_rank() |
| 20 | +n_procs = comm.Get_size() |
21 | 21 |
|
| 22 | +P_prime = int(math.ceil(math.sqrt(n_procs))) |
| 23 | +C = int(math.ceil(n_procs / P_prime)) |
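| | +# P_prime x C is the logical process grid; the ceil calls guarantee |
| | +# P_prime * C >= n_procs (e.g. n_procs = 4 gives P_prime = 2, C = 2) |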
22 | 24 |
|
23 | | -P_prime = int(math.ceil(math.sqrt(nProcs))) |
24 | | -C = int(math.ceil(nProcs / P_prime)) |
25 | | - |
26 | | -if P_prime * C < nProcs: |
| 25 | +if P_prime * C < n_procs: |
27 | 26 | print("No. of procs has to be a square number") |
28 | 27 | exit(-1) |
29 | 28 |
|
|
39 | 38 | my_layer = rank // P_prime |
40 | 39 |
|
41 | 40 | # sub-communicators |
42 | | -layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer |
43 | | -group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group |
| 41 | +layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer |
| 42 | +group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group |
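| | +# comm.Split: ranks passing the same color share one sub-communicator; |
| | +# key fixes each rank's ordering inside it |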
44 | 43 |
|
45 | 44 | # Each rank will end up with: |
46 | 45 | # A_p: shape (my_own_rows, K) |
47 | 46 | # B_p: shape (K, my_own_cols) |
48 | 47 | # where |
49 | | -row_start = my_group * blk_rows |
50 | | -row_end = min(M, row_start + blk_rows) |
| 48 | +row_start = my_group * blk_rows |
| 49 | +row_end = min(M, row_start + blk_rows) |
51 | 50 | my_own_rows = row_end - row_start |
52 | 51 |
|
53 | | -col_start = my_group * blk_cols # note: same my_group index on cols |
54 | | -col_end = min(N, col_start + blk_cols) |
| 52 | +col_start = my_group * blk_cols # note: same my_group index on cols |
| 53 | +col_end = min(N, col_start + blk_cols) |
55 | 54 | my_own_cols = col_end - col_start |
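| | +# rows of A and cols of B are both indexed by my_group, so group g owns |
| | +# the g-th row-block of A and the g-th col-block of B |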
56 | 55 |
|
57 | 56 | # ======================= BROADCASTING THE SLICES ======================= |
58 | 57 | if rank == 0: |
59 | | - A = np.arange(M*K, dtype=np.float32).reshape(M, K) |
60 | | - B = np.arange(K*N, dtype=np.float32).reshape(K, N) |
61 | | - for dest in range(nProcs): |
| 58 | + A = np.arange(M * K, dtype=np.float32).reshape(M, K) |
| 59 | + B = np.arange(K * N, dtype=np.float32).reshape(K, N) |
| 60 | + for dest in range(n_procs): |
62 | 61 | pg = dest % P_prime |
63 | | - rs = pg*blk_rows; re = min(M, rs+blk_rows) |
64 | | - cs = pg*blk_cols; ce = min(N, cs+blk_cols) |
65 | | - a_block , b_block = A[rs:re, :].copy(), B[:, cs:ce].copy() |
| 62 | +        rs = pg * blk_rows |
| 63 | + re = min(M, rs + blk_rows) |
| 64 | +        cs = pg * blk_cols |
| 65 | + ce = min(N, cs + blk_cols) |
| 66 | + a_block, b_block = A[rs:re, :].copy(), B[:, cs:ce].copy() |
66 | 67 | if dest == 0: |
67 | 68 | A_p, B_p = a_block, b_block |
68 | 69 | else: |
69 | | - comm.Send(a_block, dest=dest, tag=100+dest) |
70 | | - comm.Send(b_block, dest=dest, tag=200+dest) |
| 70 | + comm.Send(a_block, dest=dest, tag=100 + dest) |
| 71 | + comm.Send(b_block, dest=dest, tag=200 + dest) |
71 | 72 | else: |
72 | 73 | A_p = np.empty((my_own_rows, K), dtype=np.float32) |
73 | 74 | B_p = np.empty((K, my_own_cols), dtype=np.float32) |
74 | | - comm.Recv(A_p, source=0, tag=100+rank) |
75 | | - comm.Recv(B_p, source=0, tag=200+rank) |
| 75 | + comm.Recv(A_p, source=0, tag=100 + rank) |
| 76 | + comm.Recv(B_p, source=0, tag=200 + rank) |
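| | +# rank 0 keeps its own blocks and ships every other rank its slices; |
| | +# distinct tags (100 + dest for A, 200 + dest for B) keep the Recvs unambiguous |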
76 | 77 |
|
77 | 78 | comm.Barrier() |
78 | 79 |
|
79 | | -Aop = MPISUMMAMatrixMult(A_p, N) |
80 | | -col_lens = comm.allgather(my_own_cols) |
| 80 | +Aop = MPISUMMAMatrixMult(A_p, N) |
| 81 | +col_lens = comm.allgather(my_own_cols) |
81 | 82 | total_cols = np.add.reduce(col_lens, 0) |
82 | 83 | x = DistributedArray(global_shape=K * total_cols, |
83 | 84 | local_shapes=[K * col_len for col_len in col_lens], |
|
88 | 89 | y = Aop @ x |
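| | +# forward: y holds this rank's flattened column block of C = A @ B, |
| | +# as checked against C_true below |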
89 | 90 |
|
90 | 91 | # ======================= VERIFICATION ======================= |
91 | | -A = np.arange(M*K).reshape(M, K).astype(np.float32) |
92 | | -B = np.arange(K*N).reshape(K, N).astype(np.float32) |
| 92 | +A = np.arange(M * K).reshape(M, K).astype(np.float32) |
| 93 | +B = np.arange(K * N).reshape(K, N).astype(np.float32) |
93 | 94 | C_true = A @ B |
94 | 95 | Z_true = (A.T.dot(C_true.conj())).conj() |
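| | +# A and C_true are real, so the conj calls are no-ops and |
| | +# Z_true = A.T @ C_true, i.e. the adjoint A^H applied to C |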
95 | 96 |
|
96 | | - |
97 | | -col_start = my_layer * blk_cols # note: same my_group index on cols |
98 | | -col_end = min(N, col_start + blk_cols) |
| 97 | +col_start = my_layer * blk_cols  # note: cols are now indexed by my_layer |
| 98 | +col_end = min(N, col_start + blk_cols) |
99 | 99 | my_own_cols = col_end - col_start |
100 | | -expected_y = C_true[:,col_start:col_end].flatten() |
| 100 | +expected_y = C_true[:, col_start:col_end].flatten() |
101 | 101 |
|
102 | 102 | if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14): |
103 | 103 | print(f"RANK {rank}: FORWARD VERIFICATION FAILED") |
104 | | - print(f'{rank} local: {y.local_array}, expected: {C_true[:,col_start:col_end]}') |
| 104 | + print(f'{rank} local: {y.local_array}, expected: {C_true[:, col_start:col_end]}') |
105 | 105 | else: |
106 | 106 | print(f"RANK {rank}: FORWARD VERIFICATION PASSED") |
107 | 107 |
|
108 | 108 | z = Aop.H @ y |
109 | | -expected_z = Z_true[:,col_start:col_end].flatten() |
| 109 | +expected_z = Z_true[:, col_start:col_end].flatten() |
110 | 110 | if not np.allclose(z.local_array, expected_z, atol=1e-6, rtol=1e-14): |
111 | 111 | print(f"RANK {rank}: ADJOINT VERIFICATION FAILED") |
112 | | - print(f'{rank} local: {z.local_array}, expected: {Z_true[:,col_start:col_end]}') |
| 112 | + print(f'{rank} local: {z.local_array}, expected: {Z_true[:, col_start:col_end]}') |
113 | 113 | else: |
114 | 114 | print(f"RANK {rank}: ADJOINT VERIFICATION PASSED") |