
Commit 5155186

Fix flake issue
1 parent 7939587 commit 5155186

File tree

1 file changed (+27, -31)


pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 27 additions & 31 deletions
@@ -82,7 +82,7 @@ def local_block_split(global_shape: Tuple[int, int],
                       comm: MPI.Comm) -> Tuple[slice, slice]:
     r"""Local sub‐block of a 2D global array
 
-    Compute the local sub‐block of a 2D global array for a process in a square 
+    Compute the local sub‐block of a 2D global array for a process in a square
     process grid.
 
     Parameters
@@ -106,9 +106,8 @@ def local_block_split(global_shape: Tuple[int, int],
     ValueError
         If `rank` is not an integer value or out of range.
     RuntimeError
-        If the number of processes participating in the provided communicator 
+        If the number of processes participating in the provided communicator
         is not a perfect square.
-
     """
     size = comm.Get_size()
     p_prime = math.isqrt(size)
@@ -130,7 +129,7 @@ def local_block_split(global_shape: Tuple[int, int],
 
 def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):
     r"""Local block from 2D block distributed matrix
-    
+
     Gather distributed local blocks from 2D block distributed matrix distributed
     amongst a square process grid into the full global array.
 
@@ -152,9 +151,8 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):
     Raises
     ------
     RuntimeError
-        If the number of processes participating in the provided communicator 
+        If the number of processes participating in the provided communicator
         is not a perfect square.
-
     """
     ncp = get_module(x.engine)
     p_prime = math.isqrt(comm.Get_size())
@@ -169,7 +167,7 @@ def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):
         pr, pc = divmod(rank, p_prime)
         rs, cs = pr * br, pc * bc
         re, ce = min(rs + br, nr), min(cs + bc, nc)
-        if len(all_blks[rank]) !=0:
+        if len(all_blks[rank]) != 0:
             C[rs:re, cs:ce] = all_blks[rank].reshape(re - rs, cs - ce)
     return C
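For reference, the rank-to-block mapping used by `local_block_split` and `block_gather` can be reproduced in a few lines. The sketch below only illustrates that indexing on a square process grid; the `block_slices` helper and the ceil-based block sizes are assumptions for the example, not the library code.

```python
import math

def block_slices(global_shape, rank, size):
    # Map a rank on a sqrt(P) x sqrt(P) grid to the slices of its 2D block.
    nr, nc = global_shape
    p_prime = math.isqrt(size)
    if p_prime * p_prime != size:
        raise RuntimeError("number of processes must be a perfect square")
    br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime)  # local block sizes
    pr, pc = divmod(rank, p_prime)                             # grid coordinates
    rs, cs = pr * br, pc * bc                                  # block start
    re, ce = min(rs + br, nr), min(cs + bc, nc)                # clip at array edges
    return slice(rs, re), slice(cs, ce)

# e.g. rank 3 on a 2x2 grid of a 5x6 array owns rows 3:5 and columns 3:6
print(block_slices((5, 6), rank=3, size=4))
```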

@@ -519,11 +517,11 @@ def __init__(
         size = base_comm.Get_size()
 
         # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
-        self._P_prime = math.isqrt(size) 
+        self._P_prime = math.isqrt(size)
         if self._P_prime * self._P_prime != size:
             raise Exception(f"Number of processes must be a square number, provided {size} instead...")
 
-        self._row_id, self._col_id = divmod(rank, self._P_prime) 
+        self._row_id, self._col_id = divmod(rank, self._P_prime)
 
         self.base_comm = base_comm
         self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
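The `Split` pattern above can be tried in isolation. The standalone mpi4py sketch below shows how `color`/`key` yield one communicator per grid row and one per grid column; the script name and the column-communicator line are assumptions for the example (only the row split appears in this hunk). Run it with e.g. `mpiexec -n 4 python grid_split.py`.

```python
# Illustration of the communicator-splitting pattern on a square process grid.
import math
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
p_prime = math.isqrt(size)
assert p_prime * p_prime == size, "run with a square number of ranks"

row_id, col_id = divmod(rank, p_prime)
row_comm = comm.Split(color=row_id, key=col_id)   # ranks sharing a grid row
col_comm = comm.Split(color=col_id, key=row_id)   # ranks sharing a grid column

print(f"world rank {rank}: grid ({row_id}, {col_id}), "
      f"row-comm rank {row_comm.Get_rank()}, col-comm rank {col_comm.Get_rank()}")
```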
@@ -541,7 +539,7 @@ def __init__(
 
         bn = self._N_padded // self._P_prime
         bk = self._K_padded // self._P_prime
-        bm = self._M_padded // self._P_prime
+        bm = self._M_padded // self._P_prime  # noqa: F841
 
         pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0
         pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0
@@ -552,7 +550,7 @@ def __init__(
         if saveAt:
             self.At = self.A.T.conj()
 
-        self.dims = (self.K, self.M) 
+        self.dims = (self.K, self.M)
         self.dimsd = (self.N, self.M)
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
@@ -597,7 +595,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if pad_k > 0 or pad_m > 0:
             x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
 
-        Y_local = ncp.zeros((self.A.shape[0], bm),dtype=output_dtype)
+        Y_local = ncp.zeros((self.A.shape[0], bm), dtype=output_dtype)
 
         for k in range(self._P_prime):
             Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A)
@@ -690,19 +688,18 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
 
 
 def MPIMatrixMult(
-    A: NDArray,
-    M: int,
-    saveAt: bool = False,
-    base_comm: MPI.Comm = MPI.COMM_WORLD,
-    kind: Literal["summa", "block"] = "summa",
-    dtype: DTypeLike = "float64",
-):
+        A: NDArray,
+        M: int,
+        saveAt: bool = False,
+        base_comm: MPI.Comm = MPI.COMM_WORLD,
+        kind: Literal["summa", "block"] = "summa",
+        dtype: DTypeLike = "float64"):
     r"""
     MPI Distributed Matrix Multiplication Operator
 
     This operator performs distributed matrix-matrix multiplication
-    using either the SUMMA (Scalable Universal Matrix Multiplication 
-    Algorithm [1]_) or a 1D block-row decomposition algorithm (based on the 
+    using either the SUMMA (Scalable Universal Matrix Multiplication
+    Algorithm [1]_) or a 1D block-row decomposition algorithm (based on the
     specified ``kind`` parameter).
 
     Parameters
@@ -712,7 +709,7 @@ def MPIMatrixMult(
     M : :obj:`int`
         Global number of columns in the operand and result matrices.
     saveAt : :obj:`bool`, optional
-        If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose 
+        If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose
         :math:`\mathbf{A}^H` to accelerate adjoint operations (uses twice the
         memory). Default is ``False``.
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
@@ -729,8 +726,7 @@ def MPIMatrixMult(
     shape : :obj:`tuple`
         Operator shape
     kind : :obj:`str`, optional
-        Selected distributed matrix multiply algorithm (``'block'`` or 
-        ``'summa'``).
+        Selected distributed matrix multiply algorithm (``'block'`` or ``'summa'``).
 
     Raises
     ------
@@ -739,7 +735,7 @@ def MPIMatrixMult(
     Exception
         If the MPI communicator does not form a compatible grid for the
         selected algorithm.
-    
+
     Notes
     -----
     The forward operator computes:
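(The formula itself falls outside the changed hunks. For orientation: given the `dims = (K, M)` and `dimsd = (N, M)` set in `__init__`, the end-to-end action of the operator is the ordinary product Y = A X in the forward and A^H Y in the adjoint, applied to flattened operands. The single-process NumPy sketch below illustrates that equivalence; the concrete sizes and the row-major flattening are assumptions for the example.)

```python
# Single-process sketch of what the distributed operator computes end to end
# (illustration only; shapes follow dims=(K, M) and dimsd=(N, M) above).
import numpy as np

N, K, M = 6, 4, 3
A = np.random.rand(N, K)
x = np.random.rand(K * M)                        # flattened model, size prod(dims)

y = (A @ x.reshape(K, M)).ravel()                # forward: flattened (N, M) data
xadj = (A.conj().T @ y.reshape(N, M)).ravel()    # adjoint: flattened (K, M) model
print(y.shape, xadj.shape)                       # (18,) (12,)
```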
@@ -762,28 +758,28 @@ def MPIMatrixMult(
 
     Based on the choice of ``kind``, the distribution layouts of the operator and model and
     data vectors differ as follows:
-    
+
     :summa:
 
         2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
 
-        - :math:`\mathbf{A}` and :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into 
-          :math:`[N_{loc} \times K_{loc}]` and :math:`[K_{loc} \times M_{loc}]` tiles on each 
+        - :math:`\mathbf{A}` and :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into
+          :math:`[N_{loc} \times K_{loc}]` and :math:`[K_{loc} \times M_{loc}]` tiles on each
           rank, respectively.
        - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and
-          :math:`\mathbf{X}` (forward) or :math:`\mathbf{Y}` (adjoint) and accumulates local 
+          :math:`\mathbf{X}` (forward) or :math:`\mathbf{Y}` (adjoint) and accumulates local
           partial products.
 
     :block:
-    
+
         1D block-row distribution over a :math:`[1 \times P]` grid:
 
        - :math:`\mathbf{A}` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
        - :math:`\mathbf{X}` (and :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_{loc}]` blocks.
        - Local multiplication is followed by row-wise gather (forward) or
          allreduce (adjoint) across ranks.
 
-    .. [1] Robert A. van de Geijn, R., and Watts, J. "SUMMA: Scalable Universal 
+    .. [1] Robert A. van de Geijn, R., and Watts, J. "SUMMA: Scalable Universal
           Matrix Multiplication Algorithm", 1995.
 
     """
