@@ -74,7 +74,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
 def local_block_spit(global_shape: Tuple[int, int],
                      rank: int,
                      comm: MPI.Comm) -> Tuple[slice, slice]:
-    """
+    r"""
     Compute the local sub-block of a 2D global array for a process in a square process grid.
 
     Parameters
@@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int],
 
 
 def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm):
-    """
+    r"""
     Gather the local blocks of a 2D block-distributed matrix, spread
     amongst a square process grid, into the full global array.
@@ -351,19 +351,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}, got {x.partition}")
-
+        output_dtype = np.result_type(self.dtype, x.dtype)
         y = DistributedArray(
             global_shape=(self.N * self.dimsd[1]),
             local_shapes=[(self.N * c) for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
         my_own_cols = self._rank_col_lens[self.rank]
         x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
-        X_local = x_arr.astype(self.dtype)
+        X_local = x_arr.astype(output_dtype)
         Y_local = ncp.vstack(
             self._row_comm.allgather(
                 ncp.matmul(self.A, X_local)
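The promotion rule introduced above is plain NumPy type promotion; nothing MPI-specific is involved. A minimal standalone sketch (the dtype pairings are illustrative, not taken from the PR's tests) of what `np.result_type` yields for the operator/input combinations this `_matvec` must handle:

```python
import numpy as np

# Casting x once to the promoted dtype keeps complex inputs intact
# under a real operator and avoids an implicit upcast inside matmul.
for op_dtype, x_dtype in [
    (np.float32, np.float32),    # real A, real x      -> float32
    (np.float32, np.complex64),  # real A, complex x   -> complex64
    (np.float64, np.complex64),  # precision promotes  -> complex128
    (np.complex64, np.float32),  # complex A, real x   -> complex64
]:
    out = np.result_type(op_dtype, x_dtype)
    print(np.dtype(op_dtype), np.dtype(x_dtype), "->", out)
```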
@@ -377,16 +377,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}, got {x.partition}")
 
+        # - If A is real: A^H = A^T,
+        #       so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex,
+        #       so result will be complex regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves input type complexity
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # But still need to check type promotion for precision
+            output_dtype = np.result_type(self.dtype, output_dtype)
+
         y = DistributedArray(
             global_shape=(self.K * self.dimsd[1]),
             local_shapes=[self.K * c for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
-        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
+        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype)
         X_tile = x_arr[self._row_start:self._row_end, :]
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.matmul(A_local, X_tile)
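The branch added in this hunk encodes that the adjoint of a real matrix cannot introduce complexity: the output is complex only if `x` is, while precision is still promoted between the two operands. A minimal sketch of the same rule outside the operator (the helper `adjoint_output_dtype` is hypothetical, written here only to mirror the branch):

```python
import numpy as np

def adjoint_output_dtype(A: np.ndarray, x: np.ndarray) -> np.dtype:
    # A^H of a real matrix is real, so complexity comes only from x;
    # precision is still promoted between the two operands.
    if np.iscomplexobj(A):
        return np.result_type(A.dtype, x.dtype)
    out = x.dtype if np.iscomplexobj(x) else A.dtype
    return np.result_type(A.dtype, out)

A = np.ones((3, 3), dtype=np.float64)
x = np.ones(3, dtype=np.complex64)
print(adjoint_output_dtype(A, x))  # complex128: complex from x, float64 precision from A
```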
@@ -536,7 +548,6 @@ def __init__(
         self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
 
         self.A = A.astype(np.dtype(dtype))
-        if saveAt: self.At = A.T.conj()
 
         self.N = self._col_comm.allreduce(A.shape[0])
         self.K = self._row_comm.allreduce(A.shape[1])
@@ -569,6 +580,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}, got {x.partition}")
 
+        output_dtype = np.result_type(self.dtype, x.dtype)
         # Calculate local shapes for block distribution
         bn = self._N_padded // self._P_prime  # block size in N dimension
         bm = self._M_padded // self._P_prime  # block size in M dimension
@@ -582,9 +594,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
                              mask=x.mask,
                              local_shapes=local_shapes,
                              partition=Partition.SCATTER,
-                             dtype=self.dtype,
-                             base_comm=self.base_comm
-                             )
+                             dtype=output_dtype,
+                             base_comm=self.base_comm)
 
         # Calculate expected padded dimensions for x
         bk = self._K_padded // self._P_prime  # block size in K dimension
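The SUMMA loop below requires every rank's tile to have the identical block shape, so ragged edge tiles are zero-padded up to `(bk, bm)` before the broadcasts. A standalone sketch of that padding arithmetic (block sizes `bk`, `bm` follow the diff; the toy sizes are made up):

```python
import numpy as np

# Toy edge tile: global K=10, M=7 on a 3x3 grid -> padded K=12, M=9,
# so bk=4, bm=3; the last tile holds only a 2x1 slice and must be padded.
bk, bm = 4, 3
x_block = np.arange(2.0).reshape(2, 1)  # ragged last tile
pad_k = bk - x_block.shape[0]           # rows missing to a full block
pad_m = bm - x_block.shape[1]           # cols missing to a full block
if pad_k > 0 or pad_m > 0:
    x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
print(x_block.shape)  # (4, 3): zeros appended on the bottom/right edges only
```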
@@ -603,13 +614,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if pad_k > 0 or pad_m > 0:
             x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
 
-        Y_local = np.zeros((self.A.shape[0], bm))
+        Y_local = np.zeros((self.A.shape[0], bm), dtype=output_dtype)
 
         for k in range(self._P_prime):
             Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A)
             Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
-            self._row_comm.bcast(Atemp, root=k)
-            self._col_comm.bcast(Xtemp, root=k)
+            self._row_comm.Bcast(Atemp, root=k)
+            self._col_comm.Bcast(Xtemp, root=k)
             Y_local += ncp.dot(Atemp, Xtemp)
 
         Y_local_unpadded = Y_local[:local_n, :local_m]
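The `bcast` to `Bcast` switch in this hunk matters in mpi4py: lowercase `bcast` pickles the object and *returns* the broadcast value, so assigning into a preallocated `np.empty_like` array is discarded unless the return value is rebound, whereas uppercase `Bcast` fills the given buffer in place without pickling. A minimal sketch of the difference (run with e.g. `mpiexec -n 2 python demo.py`):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Lowercase bcast: pickle-based, RETURNS the object; a preallocated
# buffer on non-root ranks is untouched unless you rebind it.
buf = np.arange(4.0) if comm.rank == 0 else np.empty(4)
buf = comm.bcast(buf, root=0)

# Uppercase Bcast: buffer-based, fills `buf2` in place, no pickling.
buf2 = np.arange(4.0) if comm.rank == 0 else np.empty(4)
comm.Bcast(buf2, root=0)

assert np.allclose(buf, buf2)
```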
@@ -631,13 +642,24 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
 
         local_shapes = self.base_comm.allgather(local_k * local_m)
+        # - If A is real: A^H = A^T,
+        #       so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex,
+        #       so result will be complex regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves input type complexity
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # But still need to check type promotion for precision
+            output_dtype = np.result_type(self.dtype, output_dtype)
 
         y = DistributedArray(
             global_shape=(self.K * self.M),
             mask=x.mask,
             local_shapes=local_shapes,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
@@ -659,7 +681,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
             x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
 
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
-        Y_local = np.zeros((self.A.shape[1], bm))
+        Y_local = np.zeros((self.A.shape[1], bm), dtype=output_dtype)
 
         for k in range(self._P_prime):
             requests = []
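The diff is cut off right after `requests = []`, so the actual exchange in this loop is not visible here. For orientation only, a generic non-blocking mpi4py pattern that such a requests list typically accumulates (a ring shift of a local block; every name in this sketch is hypothetical, not taken from the PR):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
send_buf = np.full(4, comm.rank, dtype=np.float64)
recv_buf = np.empty(4, dtype=np.float64)
dest = (comm.rank + 1) % comm.size
src = (comm.rank - 1) % comm.size

requests = []
requests.append(comm.Isend(send_buf, dest=dest, tag=0))   # non-blocking send
requests.append(comm.Irecv(recv_buf, source=src, tag=0))  # non-blocking receive
MPI.Request.Waitall(requests)  # complete both before touching recv_buf
```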