@@ -35,9 +35,9 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
3535    cmplx  =  1j  if  np .issubdtype (dtype , np .complexfloating ) else  0 
3636    base_float_dtype  =  np .float32  if  dtype  ==  np .complex64  else  np .float64 
3737
38-     p_prime  =  int ( math .ceil ( math . sqrt ( size )) )
39-     C  =  int ( math . ceil ( size   /   p_prime )) 
40-     assert  p_prime  *  C  ==  size 
38+     p_prime  =  math .isqrt ( size )
39+     C  =  p_prime 
40+     assert  p_prime  *  C  ==  size ,  f"Number of processes must be a square number, provided  { size }  instead..." 
4141
4242    my_group  =  rank  %  p_prime 
4343    my_layer  =  rank  //  p_prime 
@@ -90,25 +90,42 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
9090    # Adjoint operation: xadj = A.H @ y (distributed) 
9191    xadj_dist  =  Aop .H  @ y_dist 
9292
93-     y     =  y_dist .asarray (masked = True )
94-     y     =  y .reshape (p_prime , M , blk_cols_X )
93+     y  =  y_dist .asarray (masked = True )
94+     col_counts  =  [min (blk_cols_X , N  -  j  *  blk_cols_X ) for  j  in  range (p_prime )]
95+     y_blocks  =  []
96+     offset  =  0 
97+     for  cnt  in  col_counts :
98+         block_size  =  M  *  cnt 
99+         y_blocks .append (
100+             y [offset : offset  +  block_size ].reshape (M , cnt )
101+         )
102+         offset  +=  block_size 
103+     y  =  np .hstack (y_blocks )
95104
96105    xadj  =  xadj_dist .asarray (masked = True )
97-     xadj  =  xadj .reshape (p_prime , K , blk_cols_X )
106+     xadj_blocks  =  []
107+     offset  =  0 
108+     for  cnt  in  col_counts :
109+         block_size  =  K  *  cnt 
110+         xadj_blocks .append (
111+             xadj [offset : offset  +  block_size ].reshape (K , cnt )
112+         )
113+         offset  +=  block_size 
114+     xadj  =  np .hstack (xadj_blocks )
98115
99116    if  rank  ==  0 :
100-         y_loc  =  ( A_glob  @ X_glob ). squeeze () 
117+         y_loc  =  A_glob  @ X_glob 
101118        assert_allclose (
102-             y ,
103-             y_loc ,
119+             y . squeeze () ,
120+             y_loc . squeeze () ,
104121            rtol = np .finfo (np .dtype (dtype )).resolution ,
105122            err_msg = f"Rank { rank }  : Forward verification failed." 
106123        )
107124
108-         xadj_loc  =  ( A_glob .conj ().T  @ y_loc . conj ()). conj (). squeeze () 
125+         xadj_loc  =  A_glob .conj ().T  @ y_loc 
109126        assert_allclose (
110-             xadj ,
111-             xadj_loc ,
127+             xadj . squeeze () ,
128+             xadj_loc . squeeze () ,
112129            rtol = np .finfo (np .dtype (dtype )).resolution ,
113130            err_msg = f"Rank { rank }  : Ajoint verification failed." 
114131        )
0 commit comments