Skip to content

Commit db7cc07

Browse files
committed
Renamed layer and group as to row and col respectively
1 parent a85e75a commit db7cc07

File tree

3 files changed

+38
-39
lines changed

3 files changed

+38
-39
lines changed

examples/plot_matrixmult.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@
7171
# R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
7272
#
7373
# Each process is therefore assigned a pair of coordinates
74-
# :math:`(g, l)` within this grid:
74+
# :math:`(r,c)` within this grid:
7575
#
7676
# .. math::
77-
# g = \mathrm{rank} \bmod P',
77+
# r = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor,
7878
# \quad
79-
# l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor.
79+
# c = \mathrm{rank} \bmod P'.
8080
#
8181
#For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
8282
#
@@ -85,19 +85,19 @@
8585
# <div style="text-align: center; font-family: monospace; white-space: pre;">
8686
# ┌────────────┬────────────┐
8787
# │ Rank 0 │ Rank 1 │
88-
# │ (g=0, l=0) │ (g=1, l=0) │
88+
# │ (r=0, c=0) │ (r=0, c=1) │
8989
# ├────────────┼────────────┤
9090
# │ Rank 2 │ Rank 3 │
91-
# │ (g=0, l=1) │ (g=1, l=1) │
91+
# │ (r=1, c=0) │ (r=1, c=1) │
9292
# └────────────┴────────────┘
9393
# </div>
9494

95-
my_group = rank % p_prime
96-
my_layer = rank // p_prime
95+
my_col = rank % p_prime
96+
my_row = rank // p_prime
9797

9898
# Create sub‐communicators
99-
layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer
100-
group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group
99+
row_comm = comm.Split(color=my_row, key=my_col) # all procs in same row
100+
col_comm = comm.Split(color=my_col, key=my_row) # all procs in same col
101101

102102
################################################################################
103103
# At this point we divide the rows and columns of :math:`\mathbf{A}` and
@@ -111,10 +111,10 @@
111111
# <div style="text-align: left; font-family: monospace; white-space: pre;">
112112
# <b>Matrix A (4 x 4):</b>
113113
# ┌─────────────────┐
114-
# │ a11 a12 a13 a14 │ <- Rows 0–1 (Group 0)
114+
# │ a11 a12 a13 a14 │ <- Rows 0–1 (Process Grid Col 0)
115115
# │ a21 a22 a23 a24 │
116116
# ├─────────────────┤
117-
# │ a41 a42 a43 a44 │ <- Rows 2–3 (Group 1)
117+
# │ a41 a42 a43 a44 │ <- Rows 2–3 (Process Grid Col 1)
118118
# │ a51 a52 a53 a54 │
119119
# └─────────────────┘
120120
# </div>
@@ -124,7 +124,7 @@
124124
# <div style="text-align: left; font-family: monospace; white-space: pre;">
125125
# <b>Matrix X (4 x 4):</b>
126126
# ┌─────────┬─────────┐
127-
# │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1)
127+
# │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Process Grid Row 0), Cols 2–3 (Process Grid Row 1)
128128
# │ b21 b22 │ b23 b24 │
129129
# │ b31 b32 │ b33 b34 │
130130
# │ b41 b42 │ b43 b44 │
@@ -135,11 +135,11 @@
135135
blk_rows = int(math.ceil(N / p_prime))
136136
blk_cols = int(math.ceil(M / p_prime))
137137

138-
rs = my_group * blk_rows
138+
rs = my_col * blk_rows
139139
re = min(N, rs + blk_rows)
140140
my_own_rows = re - rs
141141

142-
cs = my_layer * blk_cols
142+
cs = my_row * blk_cols
143143
ce = min(M, cs + blk_cols)
144144
my_own_cols = ce - cs
145145

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,16 @@ class MPIMatrixMult(MPILinearOperator):
8282
of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
8383
is the number of columns assigned to the current process.
8484
85-
2. **Data Broadcasting**: Within each layer (processes with same ``layer_id``),
86-
the operand data is broadcast from the process whose ``group_id`` matches
87-
the ``layer_id``. This ensures all processes in a layer have access to
88-
the same operand columns.
85+
2. **Data Broadcasting**: Within each row (processes with same ``row_id``),
86+
the operand data is broadcast from the process whose ``col_id`` matches
87+
the ``row_id`` (processes along the diagonal). This ensures all processes
88+
in a row have access to the same operand columns.
8989
9090
3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
9191
- ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
9292
- ``X_local`` is the broadcasted operand (shape ``K x M_local``)
9393
94-
4. **Layer Gather**: Results from all processes in each layer are gathered
94+
4. **Row-wise Gather**: Results from all processes in each row are gathered
9595
using ``allgather`` to reconstruct the full result matrix vertically.
9696
9797
**Adjoint Operation step-by-step**
@@ -112,9 +112,9 @@ class MPIMatrixMult(MPILinearOperator):
112112
producing a partial result of shape ``(K, M_local)``.
113113
This computes the local contribution of columns of ``A^H`` to the final result.
114114
115-
3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the
115+
3. **Row-wise Reduction**: Since the full result ``Y = A^H \cdot X`` is the
116116
sum of contributions from all column blocks of ``A^H``, processes in the
117-
same layer perform an ``allreduce`` sum to combine their partial results.
117+
same rows perform an ``allreduce`` sum to combine their partial results.
118118
This gives the complete ``(K, M_local)`` result for their assigned columns.
119119
120120
"""
@@ -135,29 +135,28 @@ def __init__(
135135
if self._P_prime * self._C != size:
136136
raise Exception(f"Number of processes must be a square number, provided {size} instead...")
137137

138-
# Compute this process's group and layer indices
139-
self._group_id = rank % self._P_prime
140-
self._layer_id = rank // self._P_prime
138+
self._col_id = rank % self._P_prime
139+
self._row_id = rank // self._P_prime
141140

142141
# Split communicators by layer (rows) and by group (columns)
143142
self.base_comm = base_comm
144-
self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
145-
self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
143+
self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
144+
self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
146145

147146
self.A = A.astype(np.dtype(dtype))
148147
if saveAt: self.At = A.T.conj()
149148

150-
self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
149+
self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM)
151150
self.K = A.shape[1]
152151
self.M = M
153152

154153
block_cols = int(math.ceil(self.M / self._P_prime))
155154
blk_rows = int(math.ceil(self.N / self._P_prime))
156155

157-
self._row_start = self._group_id * blk_rows
156+
self._row_start = self._col_id * blk_rows
158157
self._row_end = min(self.N, self._row_start + blk_rows)
159158

160-
self._col_start = self._layer_id * block_cols
159+
self._col_start = self._row_id * block_cols
161160
self._col_end = min(self.M, self._col_start + block_cols)
162161

163162
self._local_ncols = self._col_end - self._col_start
@@ -184,7 +183,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
184183
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
185184
X_local = x_arr.astype(self.dtype)
186185
Y_local = ncp.vstack(
187-
self._layer_comm.allgather(
186+
self._row_comm.allgather(
188187
ncp.matmul(self.A, X_local)
189188
)
190189
)
@@ -208,6 +207,6 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
208207
X_tile = x_arr[self._row_start:self._row_end, :]
209208
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
210209
Y_local = ncp.matmul(A_local, X_tile)
211-
y_layer = self._layer_comm.allreduce(Y_local, op=MPI.SUM)
210+
y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
212211
y[:] = y_layer.flatten()
213212
return y

tests/test_matrixmult.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,20 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
4040
cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
4141
base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
4242

43-
my_group = rank % p_prime
44-
my_layer = rank // p_prime
43+
my_col = rank % p_prime
44+
my_row = rank // p_prime
4545

4646
# Create sub-communicators
47-
layer_comm = comm.Split(color=my_layer, key=my_group)
48-
group_comm = comm.Split(color=my_group, key=my_layer)
47+
row_comm = comm.Split(color=my_row, key=my_col)
48+
col_comm = comm.Split(color=my_col, key=my_row)
4949

5050
# Calculate local matrix dimensions
5151
blk_rows_A = int(math.ceil(N / p_prime))
52-
row_start_A = my_group * blk_rows_A
52+
row_start_A = my_col * blk_rows_A
5353
row_end_A = min(N, row_start_A + blk_rows_A)
5454

5555
blk_cols_X = int(math.ceil(M / p_prime))
56-
col_start_X = my_layer * blk_cols_X
56+
col_start_X = my_row * blk_cols_X
5757
col_end_X = min(M, col_start_X + blk_cols_X)
5858
local_col_X_len = max(0, col_end_X - col_start_X)
5959

@@ -131,5 +131,5 @@ def test_SUMMAMatrixMult(N, K, M, dtype_str):
131131
err_msg=f"Rank {rank}: Ajoint verification failed."
132132
)
133133

134-
group_comm.Free()
135-
layer_comm.Free()
134+
col_comm.Free()
135+
row_comm.Free()

0 commit comments

Comments
 (0)