@@ -35,27 +35,26 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
     base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
 
-    P_prime = int(math.ceil(math.sqrt(size)))
-    C = int(math.ceil(size / P_prime))
-    assert P_prime * C >= size  # Ensure process grid covers all processes
+    p_prime = int(math.ceil(math.sqrt(size)))
+    C = int(math.ceil(size / p_prime))
+    assert p_prime * C == size
 
-    my_group = rank % P_prime
-    my_layer = rank // P_prime
+    my_group = rank % p_prime
+    my_layer = rank // p_prime
 
     # Create sub-communicators
     layer_comm = comm.Split(color=my_layer, key=my_group)
     group_comm = comm.Split(color=my_group, key=my_layer)
 
     # Calculate local matrix dimensions
-    blk_rows_A = int(math.ceil(M / P_prime))
+    blk_rows_A = int(math.ceil(M / p_prime))
     row_start_A = my_group * blk_rows_A
     row_end_A = min(M, row_start_A + blk_rows_A)
-    my_own_rows_A = max(0, row_end_A - row_start_A)
 
-    blk_cols_BC = int(math.ceil(N / P_prime))
+    blk_cols_BC = int(math.ceil(N / p_prime))
     col_start_B = my_layer * blk_cols_BC
     col_end_B = min(N, col_start_B + blk_cols_BC)
-    my_own_cols_B = max(0, col_end_B - col_start_B)
+    local_col_B_len = max(0, col_end_B - col_start_B)
 
 
     A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K)
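
A minimal sketch (not part of the patch) of the grid arithmetic this hunk tightens, assuming size = 6 as a stand-in for comm.Get_size(): the == assertion now requires the p_prime x C grid to match the process count exactly, rather than merely cover it.

    import math

    size = 6                                    # illustrative process count
    p_prime = int(math.ceil(math.sqrt(size)))   # 3
    C = int(math.ceil(size / p_prime))          # 2
    assert p_prime * C == size                  # exact cover, as asserted above
    # note: a count like size = 5 gives p_prime = 3, C = 2, so the tightened
    # assertion rejects process counts that do not factor into such a grid

    for rank in range(size):
        my_group = rank % p_prime    # position within a layer
        my_layer = rank // p_prime   # layer owning this rank
        # rank -> (group, layer): 0->(0,0), 1->(1,0), 2->(2,0), 3->(0,1), ...
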
@@ -73,33 +72,28 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
     Aop = MPIMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
 
     # Create DistributedArray for input x (representing B flattened)
-    all_my_own_cols_B = comm.allgather(my_own_cols_B)
-    total_cols = np.sum(all_my_own_cols_B)
+    all_local_col_len = comm.allgather(local_col_B_len)
+    total_cols = np.sum(all_local_col_len)
 
     x_dist = DistributedArray(
         global_shape=(K * total_cols),
-        local_shapes=[K * cl_b for cl_b in all_my_own_cols_B],
+        local_shapes=[K * cl_b for cl_b in all_local_col_len],
         partition=Partition.SCATTER,
         base_comm=comm,
+        mask=[i // p_prime for i in range(size)],
         dtype=dtype
     )
 
-    if B_p.size > 0:
-        x_dist.local_array[:] = B_p.ravel()
-    else:
-        assert x_dist.local_array.size == 0, (
-            f"Rank {rank}: B_p empty but x_dist.local_array not empty "
-            f"(size {x_dist.local_array.size})"
-        )
+    x_dist.local_array[:] = B_p.ravel()
 
     # Forward operation: y = A @ B (distributed)
     y_dist = Aop @ x_dist
 
-    # Adjoint operation: z = A.H @ y (distributed y representing C)
-    z_dist = Aop.H @ y_dist
+    # Adjoint operation: xadj = A.H @ y (distributed)
+    xadj_dist = Aop.H @ y_dist
 
-    C_true = A_glob @ B_glob
-    Z_true = A_glob.conj().T @ C_true
+    y_loc = A_glob @ B_glob
+    xadj_loc = A_glob.conj().T @ y_loc
 
     col_start_C_dist   = my_layer * blk_cols_BC
     col_end_C_dist     = min(N, col_start_C_dist + blk_cols_BC)
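
A sketch of what the mask added in this hunk evaluates to, with size = 6 and p_prime = 3 as illustrative values: it tags each rank with its layer index, so ranks holding the same column block of B share a mask value.

    size, p_prime = 6, 3                          # illustrative values
    mask = [i // p_prime for i in range(size)]
    print(mask)   # [0, 0, 0, 1, 1, 1]
    # ranks 0-2 form layer 0 and ranks 3-5 form layer 1,
    # mirroring my_layer = rank // p_prime from the first hunk
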
@@ -110,27 +104,27 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
         f"Rank {rank}: y_dist shape {y_dist.local_array.shape} != expected {expected_y_shape}"
     )
 
-    if y_dist.local_array.size > 0 and C_true is not None and C_true.size > 0:
-        expected_y_slice = C_true[:, col_start_C_dist:col_end_C_dist]
+    if y_dist.local_array.size > 0 and y_loc is not None and y_loc.size > 0:
+        expected_y_slice = y_loc[:, col_start_C_dist:col_end_C_dist]
         assert_allclose(
             y_dist.local_array,
             expected_y_slice.ravel(),
             rtol=np.finfo(np.dtype(dtype)).resolution,
             err_msg=f"Rank {rank}: Forward verification failed."
         )
 
-    # Verify adjoint operation (z = A.H @ y)
-    expected_z_shape = (K * my_own_cols_C_dist,)
-    assert z_dist.local_array.shape == expected_z_shape, (
-        f"Rank {rank}: z_dist shape {z_dist.local_array.shape} != expected {expected_z_shape}"
+    # Verify adjoint operation (xadj = A.H @ y)
+    expected_xadj_shape = (K * my_own_cols_C_dist,)
+    assert xadj_dist.local_array.shape == expected_xadj_shape, (
+        f"Rank {rank}: xadj_dist shape {xadj_dist.local_array.shape} != expected {expected_xadj_shape}"
     )
 
     # Verify adjoint result values
-    if z_dist.local_array.size > 0 and Z_true is not None and Z_true.size > 0:
-        expected_z_slice = Z_true[:, col_start_C_dist:col_end_C_dist]
+    if xadj_dist.local_array.size > 0 and xadj_loc is not None and xadj_loc.size > 0:
+        expected_xadj_slice = xadj_loc[:, col_start_C_dist:col_end_C_dist]
         assert_allclose(
-            z_dist.local_array,
-            expected_z_slice.ravel(),
+            xadj_dist.local_array,
+            expected_xadj_slice.ravel(),
             rtol=np.finfo(np.dtype(dtype)).resolution,
             err_msg=f"Rank {rank}: Adjoint verification failed."
         )
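
For reference, the y_loc and xadj_loc values these assertions compare against are plain single-process NumPy products; a minimal sketch with illustrative shapes (M, K, N = 4, 5, 3) and an example column slice:

    import numpy as np

    M, K, N = 4, 5, 3                       # illustrative shapes
    A_glob = np.arange(M * K, dtype=np.float64).reshape(M, K)
    B_glob = np.arange(K * N, dtype=np.float64).reshape(K, N)

    y_loc = A_glob @ B_glob                 # forward reference: A @ B
    xadj_loc = A_glob.conj().T @ y_loc      # adjoint reference: A^H @ (A @ B)

    # each rank compares its flattened local block against a column slice,
    # here columns [0, 2) as an example:
    expected_y_slice = y_loc[:, 0:2]
    assert expected_y_slice.ravel().shape == (M * 2,)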