@@ -82,7 +82,7 @@ def local_block_split(global_shape: Tuple[int, int],
                       comm: MPI.Comm) -> Tuple[slice, slice]:
     r"""Local sub-block of a 2D global array

-    Compute the local sub-block of a 2D global array for a process in a square  
+    Compute the local sub-block of a 2D global array for a process in a square
     process grid.

     Parameters
@@ -106,9 +106,8 @@ def local_block_split(global_shape: Tuple[int, int],
     ValueError
         If `rank` is not an integer value or out of range.
     RuntimeError
-        If the number of processes participating in the provided communicator  
+        If the number of processes participating in the provided communicator
         is not a perfect square.
-    
     """
     size = comm.Get_size()
     p_prime = math.isqrt(size)
@@ -130,7 +129,7 @@ def local_block_split(global_shape: Tuple[int, int],

 def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):
     r"""Local block from 2D block distributed matrix
-    
+
     Gather the distributed local blocks of a 2D block-distributed matrix spread
     amongst a square process grid into the full global array.

@@ -152,9 +151,8 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Com
     Raises
     ------
     RuntimeError
-        If the number of processes participating in the provided communicator  
+        If the number of processes participating in the provided communicator
         is not a perfect square.
-    
     """
     ncp = get_module(x.engine)
     p_prime = math.isqrt(comm.Get_size())
@@ -169,7 +167,7 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Com
         pr, pc = divmod(rank, p_prime)
         rs, cs = pr * br, pc * bc
         re, ce = min(rs + br, nr), min(cs + bc, nc)
-        if len(all_blks[rank]) != 0:
+        if len(all_blks[rank]) != 0:
             C[rs:re, cs:ce] = all_blks[rank].reshape(re - rs, ce - cs)
     return C

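
For orientation, here is a minimal sketch of how the split helper above might be exercised. It assumes the signature ``local_block_split(global_shape, rank, comm)`` suggested by the hunks, that the function is importable from the module in this diff, and a perfect-square number of ranks (e.g. ``mpiexec -n 4``)::

    import numpy as np
    from mpi4py import MPI

    # local_block_split is assumed importable from the module shown in this diff.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Row/column slices of the 5 x 6 global array owned by this rank
    # (rank r sits at grid position divmod(r, sqrt(P))).
    rows, cols = local_block_split((5, 6), rank, comm)

    A_global = np.arange(30).reshape(5, 6)
    A_local = A_global[rows, cols]   # e.g. rank 0 holds the top-left tile

``block_gather`` then performs the inverse operation, reassembling the full global array from the per-rank tiles collected over the communicator.
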
@@ -519,11 +517,11 @@ def __init__(
         size = base_comm.Get_size()

         # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
-        self._P_prime =  math.isqrt(size)
+        self._P_prime = math.isqrt(size)
         if self._P_prime * self._P_prime != size:
             raise Exception(f"Number of processes must be a square number, provided {size}")

-        self._row_id, self._col_id =  divmod(rank, self._P_prime)
+        self._row_id, self._col_id = divmod(rank, self._P_prime)

         self.base_comm = base_comm
         self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
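
The grid bookkeeping in this hunk is plain integer arithmetic, so the rank-to-coordinate mapping can be reproduced standalone (a hypothetical 9-rank run, i.e. a 3 x 3 grid)::

    import math

    size = 9                      # number of MPI ranks; must be a perfect square
    p_prime = math.isqrt(size)    # side of the process grid -> 3
    coords = [divmod(r, p_prime) for r in range(size)]
    # coords == [(0, 0), (0, 1), (0, 2), (1, 0), ..., (2, 2)]:
    # rank r sits at grid row r // p_prime and grid column r % p_prime,
    # matching the _row_id/_col_id attributes set above.

The ``Split`` call on the last context line then groups ranks that share a grid row into a row communicator (color = row id, ranked by column id); a matching column communicator is presumably created just after this hunk.
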
@@ -541,7 +539,7 @@ def __init__(

         bn = self._N_padded // self._P_prime
         bk = self._K_padded // self._P_prime
-        bm = self._M_padded // self._P_prime
+        bm = self._M_padded // self._P_prime  # noqa: F841

         pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0
         pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0
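
Only ranks on the last grid row/column receive non-zero padding. Assuming ``self._N_padded`` and ``self._K_padded`` round the global sizes up to the next multiple of the grid side (which is what the ``bn``/``bk`` block sizes above suggest; the padding computation itself lies outside this hunk), the amounts work out as in this toy calculation::

    import math

    N, K, p_prime = 10, 7, 3                            # hypothetical global sizes, 3 x 3 grid
    N_padded = math.ceil(N / p_prime) * p_prime         # 12
    K_padded = math.ceil(K / p_prime) * p_prime         # 9
    bn, bk = N_padded // p_prime, K_padded // p_prime   # block sizes (4, 3)
    # The last grid row owns only 10 - 2 * 4 = 2 rows of A, so it zero-pads
    # pr = bn - 2 = 2 extra rows; ranks in the interior pad nothing.
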
@@ -552,7 +550,7 @@ def __init__(
         if saveAt:
             self.At = self.A.T.conj()

-        self.dims  = (self.K, self.M)
+        self.dims = (self.K, self.M)
         self.dimsd = (self.N, self.M)
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
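
The flattened operator shape set here is simply the product of the global 2D model and data shapes; for hypothetical global sizes this gives::

    import numpy as np

    N, K, M = 8, 6, 4
    dims, dimsd = (K, M), (N, M)                         # model and data 2D shapes
    shape = (int(np.prod(dimsd)), int(np.prod(dims)))    # (32, 24)
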
@@ -597,7 +595,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if pad_k > 0 or pad_m > 0:
             x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')

-        Y_local = ncp.zeros((self.A.shape[0], bm),dtype=output_dtype)
+        Y_local = ncp.zeros((self.A.shape[0], bm), dtype=output_dtype)

         for k in range(self._P_prime):
             Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A)
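
The loop opening at the end of this hunk is the standard SUMMA accumulation: at step ``k`` the owners broadcast their block of :math:`\mathbf{A}` along the grid row and their block of :math:`\mathbf{X}` along the grid column, and every rank adds one partial product. A serial analogue (no MPI, hypothetical 2 x 2 tiling) shows what the accumulation computes::

    import numpy as np

    p_prime, b = 2, 3                               # grid side and block size
    A = np.random.rand(p_prime * b, p_prime * b)
    X = np.random.rand(p_prime * b, p_prime * b)

    Y = np.zeros_like(A)
    for k in range(p_prime):
        # On the grid, rank (i, j) would receive A's (i, k) block over its row
        # communicator and X's (k, j) block over its column communicator.
        Y += A[:, k * b:(k + 1) * b] @ X[k * b:(k + 1) * b, :]

    assert np.allclose(Y, A @ X)
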
@@ -690,19 +688,18 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:


 def MPIMatrixMult(
-            A: NDArray,
-            M: int,
-            saveAt: bool = False,
-            base_comm: MPI.Comm = MPI.COMM_WORLD,
-            kind: Literal["summa", "block"] = "summa",
-            dtype: DTypeLike = "float64",
-    ):
+        A: NDArray,
+        M: int,
+        saveAt: bool = False,
+        base_comm: MPI.Comm = MPI.COMM_WORLD,
+        kind: Literal["summa", "block"] = "summa",
+        dtype: DTypeLike = "float64"):
     r"""
     MPI Distributed Matrix Multiplication Operator

     This operator performs distributed matrix-matrix multiplication
-    using either the SUMMA (Scalable Universal Matrix Multiplication  
-    Algorithm [1]_) or a 1D block-row decomposition algorithm (based on the  
+    using either the SUMMA (Scalable Universal Matrix Multiplication
+    Algorithm [1]_) or a 1D block-row decomposition algorithm (based on the
     specified ``kind`` parameter).

     Parameters
@@ -712,7 +709,7 @@ def MPIMatrixMult(
     M : :obj:`int`
         Global number of columns in the operand and result matrices.
     saveAt : :obj:`bool`, optional
-        If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose  
+        If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose
         :math:`\mathbf{A}^H` to accelerate adjoint operations (uses twice the
         memory). Default is ``False``.
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
@@ -729,8 +726,7 @@ def MPIMatrixMult(
     shape : :obj:`tuple`
         Operator shape
     kind : :obj:`str`, optional
-        Selected distributed matrix multiply algorithm (``'block'`` or 
-        ``'summa'``).
+        Selected distributed matrix multiply algorithm (``'block'`` or ``'summa'``).

     Raises
     ------
@@ -739,7 +735,7 @@ def MPIMatrixMult(
     Exception
         If the MPI communicator does not form a compatible grid for the
         selected algorithm.
-    
+
     Notes
     -----
     The forward operator computes:
@@ -762,28 +758,28 @@ def MPIMatrixMult(

     Depending on the choice of ``kind``, the distribution layouts of the operator,
     model, and data vectors differ as follows:
-    
+
     :summa:

     2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:

-    - :math:`\mathbf{A}` and :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into  
-      :math:`[N_{loc} \times K_{loc}]` and :math:`[K_{loc} \times M_{loc}]` tiles on each  
+    - :math:`\mathbf{A}` and :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into
+      :math:`[N_{loc} \times K_{loc}]` and :math:`[K_{loc} \times M_{loc}]` tiles on each
       rank, respectively.
     - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and
-      :math:`\mathbf{X}` (forward) or :math:`\mathbf{Y}` (adjoint) and accumulates local  
+      :math:`\mathbf{X}` (forward) or :math:`\mathbf{Y}` (adjoint) and accumulates local
       partial products.

     :block:
-    
+
     1D block-row distribution over a :math:`[1 \times P]` grid:

     - :math:`\mathbf{A}` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
     - :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_{loc}]` blocks.
     - Local multiplication is followed by row-wise gather (forward) or
       allreduce (adjoint) across ranks.

-    .. [1] van de Geijn, R. A., and Watts, J. "SUMMA: Scalable Universal  
+    .. [1] van de Geijn, R. A., and Watts, J. "SUMMA: Scalable Universal
        Matrix Multiplication Algorithm", 1995.

     """
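
Finally, a hedged end-to-end construction sketch under the ``summa`` layout described above (run with a square rank count such as ``mpiexec -n 4``). The global sizes, the use of ``local_block_split`` to carve out this rank's tile, and the all-ones tile are illustrative assumptions, not fixed API behaviour::

    import numpy as np
    from mpi4py import MPI

    # local_block_split and MPIMatrixMult are assumed importable from the module in this diff.
    N, K, M = 8, 6, 4                                # hypothetical global operator sizes
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # This rank's [N_loc x K_loc] tile of the global A, following the summa layout.
    rows, cols = local_block_split((N, K), rank, comm)
    A_local = np.ones((rows.stop - rows.start, cols.stop - cols.start))

    Aop = MPIMatrixMult(A_local, M=M, kind="summa", base_comm=comm)
    # Forward and adjoint then act on DistributedArray model vectors that hold the
    # flattened local [K_loc x M_loc] (forward) and [N_loc x M_loc] (adjoint) tiles.
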