@@ -51,28 +51,27 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
5151 row_start_A = my_group * blk_rows_A
5252 row_end_A = min (M , row_start_A + blk_rows_A )
5353
54- blk_cols_BC = int (math .ceil (N / p_prime ))
55- col_start_B = my_layer * blk_cols_BC
56- col_end_B = min (N , col_start_B + blk_cols_BC )
57- local_col_B_len = max (0 , col_end_B - col_start_B )
58-
54+ blk_cols_X = int (math .ceil (N / p_prime ))
55+ col_start_X = my_layer * blk_cols_X
56+ col_end_X = min (N , col_start_X + blk_cols_X )
57+ local_col_X_len = max (0 , col_end_X - col_start_X )
5958
6059 A_glob_real = np .arange (M * K , dtype = base_float_dtype ).reshape (M , K )
6160 A_glob_imag = np .arange (M * K , dtype = base_float_dtype ).reshape (M , K ) * 0.5
6261 A_glob = (A_glob_real + cmplx * A_glob_imag ).astype (dtype )
6362
64- B_glob_real = np .arange (K * N , dtype = base_float_dtype ).reshape (K , N )
65- B_glob_imag = np .arange (K * N , dtype = base_float_dtype ).reshape (K , N ) * 0.7
66- B_glob = (B_glob_real + cmplx * B_glob_imag ).astype (dtype )
63+ X_glob_real = np .arange (K * N , dtype = base_float_dtype ).reshape (K , N )
64+ X_glob_imag = np .arange (K * N , dtype = base_float_dtype ).reshape (K , N ) * 0.7
65+ X_glob = (X_glob_real + cmplx * X_glob_imag ).astype (dtype )
6766
6867 A_p = A_glob [row_start_A :row_end_A ,:]
69- B_p = B_glob [:,col_start_B : col_end_B ]
68+ X_p = X_glob [:,col_start_X : col_end_X ]
7069
71- # Create SUMMAMatrixMult operator
70+ # Create MPIMatrixMult operator
7271 Aop = MPIMatrixMult (A_p , N , base_comm = comm , dtype = dtype_str )
7372
7473 # Create DistributedArray for input x (representing X flattened)
75- all_local_col_len = comm .allgather (local_col_B_len )
74+ all_local_col_len = comm .allgather (local_col_X_len )
7675 total_cols = np .sum (all_local_col_len )
7776
7877 x_dist = DistributedArray (
@@ -84,49 +83,34 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
8483 dtype = dtype
8584 )
8685
87- x_dist .local_array [:] = B_p .ravel ()
86+ x_dist .local_array [:] = X_p .ravel ()
8887
9089 # Forward operation: y = A @ X (distributed)
9089 y_dist = Aop @ x_dist
91-
9290 # Adjoint operation: xadj = A.H @ y (distributed)
9391 xadj_dist = Aop .H @ y_dist
9492
95- y_loc = A_glob @ B_glob
96- xadj_loc = A_glob .conj ().T @ y_loc
97-
98- col_start_C_dist = my_layer * blk_cols_BC
99- col_end_C_dist = min (N , col_start_C_dist + blk_cols_BC )
100- my_own_cols_C_dist = max (0 , col_end_C_dist - col_start_C_dist )
101- expected_y_shape = (M * my_own_cols_C_dist ,)
93+ y = y_dist .asarray (masked = True )
94+ y = y .reshape (p_prime , M , blk_cols_X )
10295
103- assert y_dist .local_array .shape == expected_y_shape , (
104- f"Rank { rank } : y_dist shape { y_dist .local_array .shape } != expected { expected_y_shape } "
105- )
96+ xadj = xadj_dist .asarray (masked = True )
97+ xadj = xadj .reshape (p_prime , K , blk_cols_X )
10698
107- if y_dist . local_array . size > 0 and y_loc is not None and y_loc . size > 0 :
108- expected_y_slice = y_loc [:, col_start_C_dist : col_end_C_dist ]
99+ if rank == 0 :
100+ y_loc = ( A_glob @ X_glob ). squeeze ()
109101 assert_allclose (
110- y_dist . local_array ,
111- expected_y_slice . ravel () ,
102+ y ,
103+ y_loc ,
112104 rtol = np .finfo (np .dtype (dtype )).resolution ,
113105 err_msg = f"Rank { rank } : Forward verification failed."
114106 )
115107
116- # Verify adjoint operation (xadj = A.H @ y)
117- expected_xadj_shape = (K * my_own_cols_C_dist ,)
118- assert xadj_dist .local_array .shape == expected_xadj_shape , (
119- f"Rank { rank } : z_dist shape { xadj_dist .local_array .shape } != expected { expected_xadj_shape } "
120- )
121-
122- # Verify adjoint result values
123- if xadj_dist .local_array .size > 0 and xadj_loc is not None and xadj_loc .size > 0 :
124- expected_xadj_slice = xadj_loc [:, col_start_C_dist :col_end_C_dist ]
108+ xadj_loc = (A_glob .conj ().T @ y_loc .conj ()).conj ().squeeze ()
125109 assert_allclose (
126- xadj_dist . local_array ,
127- expected_xadj_slice . ravel () ,
110+ xadj ,
111+ xadj_loc ,
128112 rtol = np .finfo (np .dtype (dtype )).resolution ,
129- err_msg = f"Rank { rank } : Adjoint verification failed."
113+ err_msg = f"Rank { rank } : Ajoint verification failed."
130114 )
131115
132116 group_comm .Free ()
0 commit comments