 from mpi4py import MPI
 import pytest

+from pylops.basicoperators import FirstDerivative, Identity
 from pylops_mpi import DistributedArray, Partition
-from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
+from pylops_mpi.basicoperators import MPIMatrixMult, MPIBlockDiag

 np.random.seed(42)
 base_comm = MPI.COMM_WORLD
 # M, K, N are matrix dimensions A(N,K), B(K,M)
 # P_prime will be ceil(sqrt(size)).
 test_params = [
-    pytest.param(37, 37, 37, "float32", id="f32_37_37_37"),
+    pytest.param(37, 37, 37, "float64", id="f64_37_37_37"),
     pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
     pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
     pytest.param(3, 4, 5, "float32", id="f32_3_4_5"),
     pytest.param(1, 2, 1, "float64", id="f64_1_2_1",),
     pytest.param(2, 1, 3, "float32", id="f32_2_1_3",),
 ]

+def _reorganize_local_matrix(x_dist, N, M, blk_cols, p_prime):
+    """Re-organize a distributed array into a local (N, M) matrix
+    """
+    x = x_dist.asarray(masked=True)
+    col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]
+    x_blocks = []
+    offset = 0
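+    # Slice the flat gathered buffer block-column by block-column, reshape each
+    # slice to (N, cnt) and stack the blocks side by side into the full matrix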
+    for cnt in col_counts:
+        block_size = N * cnt
+        x_block = x[offset: offset + block_size]
+        if len(x_block) != 0:
+            x_blocks.append(
+                x_block.reshape(N, cnt)
+            )
+        offset += block_size
+    x = np.hstack(x_blocks)
+    return x
+

 @pytest.mark.mpi(min_size=1)
-@pytest.mark.parametrize("M, K, N, dtype_str", test_params)
+@pytest.mark.parametrize("N, K, M, dtype_str", test_params)
 def test_MPIMatrixMult(N, K, M, dtype_str):
     dtype = np.dtype(dtype_str)

     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64

-    comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
+    comm, rank, row_id, col_id, is_active = \
+        MPIMatrixMult.active_grid_comm(base_comm, N, M)
     if not is_active: return

     size = comm.Get_size()
     p_prime = math.isqrt(size)
+    cols_id = comm.allgather(col_id)
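+    # Column id of every rank in the process grid, used below as the mask for MPIBlockDiag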

     # Calculate local matrix dimensions
     blk_rows_A = int(math.ceil(N / p_prime))
@@ -52,6 +73,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
     col_end_X = min(M, col_start_X + blk_cols_X)
     local_col_X_len = max(0, col_end_X - col_start_X)

+    # Fill local matrices
     A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K)
     A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
     A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
@@ -88,32 +110,8 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
     xadj_dist = Aop.H @ y_dist

     # Re-organize in local matrix
-    y = y_dist.asarray(masked=True)
-    col_counts = [min(blk_cols_X, M - j * blk_cols_X) for j in range(p_prime)]
-    y_blocks = []
-    offset = 0
-    for cnt in col_counts:
-        block_size = N * cnt
-        y_block = y[offset: offset + block_size]
-        if len(y_block) != 0:
-            y_blocks.append(
-                y_block.reshape(N, cnt)
-            )
-        offset += block_size
-    y = np.hstack(y_blocks)
-
-    xadj = xadj_dist.asarray(masked=True)
-    xadj_blocks = []
-    offset = 0
-    for cnt in col_counts:
-        block_size = K * cnt
-        xadj_blk = xadj[offset: offset + block_size]
-        if len(xadj_blk) != 0:
-            xadj_blocks.append(
-                xadj_blk.reshape(K, cnt)
-            )
-        offset += block_size
-    xadj = np.hstack(xadj_blocks)
+    y = _reorganize_local_matrix(y_dist, N, M, blk_cols_X, p_prime)
+    xadj = _reorganize_local_matrix(xadj_dist, K, M, blk_cols_X, p_prime)

     if rank == 0:
         y_loc = A_glob @ X_glob
@@ -129,5 +127,36 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
             xadj.squeeze(),
             xadj_loc.squeeze(),
             rtol=np.finfo(np.dtype(dtype)).resolution,
-            err_msg=f"Rank {rank}: Ajoint verification failed."
-        )
+            err_msg=f"Rank {rank}: Adjoint verification failed."
+        )
+
+    # Chain with another operator
+    Dop = FirstDerivative(dims=(N, col_end_X - col_start_X),
+                          axis=0, dtype=dtype)
+    DBop = MPIBlockDiag(ops=[Dop, ], base_comm=comm, mask=cols_id)
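+    # Block-diagonal wrapper: each rank applies Dop to its local block of columns,
+    # with the gathered column ids acting as the mask over the process grid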
+    Op = DBop @ Aop
+
+    y1_dist = Op @ x_dist
+    xadj1_dist = Op.H @ y1_dist
+
+    # Re-organize in local matrix
+    y1 = _reorganize_local_matrix(y1_dist, N, M, blk_cols_X, p_prime)
+    xadj1 = _reorganize_local_matrix(xadj1_dist, K, M, blk_cols_X, p_prime)
+
+    if rank == 0:
+        Dop_glob = FirstDerivative(dims=(N, M), axis=0, dtype=dtype)
+        y1_loc = (Dop_glob @ (A_glob @ X_glob).ravel()).reshape(N, M)
+        assert_allclose(
+            y1.squeeze(),
+            y1_loc.squeeze(),
+            rtol=np.finfo(np.dtype(dtype)).resolution,
+            err_msg=f"Rank {rank}: Forward verification failed."
+        )
+
+        xadj1_loc = A_glob.conj().T @ (Dop_glob.H @ y1_loc.ravel()).reshape(N, M)
+        assert_allclose(
+            xadj1.squeeze(),
+            xadj1_loc.squeeze(),
+            rtol=np.finfo(np.dtype(dtype)).resolution,
+            err_msg=f"Rank {rank}: Adjoint verification failed."
+        )