"""Test script: distributed matrix multiply C = A @ B with pylops_mpi's MPIMatrixMult.

A (N x K) and B (K x M) are built identically on every rank, then each rank
slices out its own 2D block according to a sqrt(P) x sqrt(P) process grid.
Run under MPI with a perfect-square number of processes, e.g.:
    mpirun -n 4 python this_script.py
"""
import math

import numpy as np
from mpi4py import MPI

import pylops_mpi
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Problem sizes: A is (N, K), B is (K, M), C is (N, M).
N = 8
M = 8
K = 8

A_shape = (N, K)
B_shape = (K, M)
C_shape = (N, M)

# Processes are arranged on a p_prime x p_prime grid, so the world size
# must be a perfect square. Raise (not assert) so the check survives -O.
p_prime = math.isqrt(size)
if p_prime * p_prime != size:
    raise ValueError(
        f"Number of processes must be a perfect square, got {size}"
    )

# Full matrices, replicated on every rank (small test data only).
A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape)

# Per-dimension block boundaries for the process grid.
N_starts, N_ends = MPIMatrixMult.block_distribute(N, p_prime)
M_starts, M_ends = MPIMatrixMult.block_distribute(M, p_prime)
K_starts, K_ends = MPIMatrixMult.block_distribute(K, p_prime)

# Grid coordinates of this rank: row i, column j.
i, j = divmod(rank, p_prime)
# Local 2D blocks: A block (i, j) over (N, K); B block (i, j) over (K, M).
A_local = A_data[N_starts[i]:N_ends[i], K_starts[j]:K_ends[j]]
B_local = B_data[K_starts[i]:K_ends[i], M_starts[j]:M_ends[j]]

# Flattened distributed B; each rank contributes its block, row-major.
# NOTE(review): global_shape=(K * M) is a plain int, not a 1-tuple —
# confirm DistributedArray accepts an int global shape.
B_dist = pylops_mpi.DistributedArray(
    global_shape=(K * M),
    local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]),
    base_comm=comm,
    partition=pylops_mpi.Partition.SCATTER,
)
B_dist.local_array[:] = B_local.flatten()

# Debug: show each rank's local A block shape.
print(rank, A_local.shape)

Aop = MPIMatrixMult(A_local, M, base_comm=comm)
C_dist = Aop @ B_dist

# Gather the distributed result and undo the block layout: the gathered
# array is ordered by (grid-row, local-row, grid-col, local-col), so a
# reshape + transpose recovers the plain (N, M) matrix.
C_temp = C_dist.asarray().reshape((N, M))
C = (
    C_temp.reshape(N // p_prime, p_prime, p_prime, M // p_prime)
    .transpose(1, 0, 2, 3)
    .reshape(N, M)
)

if rank == 0:
    print("expected:\n", A_data @ B_data)
    print("calculated:\n", C)
0 commit comments