Skip to content

Commit 13025d7

Browse files
committed
minor: moved active_grid_comm before _matvec
1 parent 286ab22 commit 13025d7

File tree

1 file changed

+27
-27
lines changed

1 file changed

+27
-27
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -163,37 +163,12 @@ def __init__(
163163
shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
164164
super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
165165

166-
def _matvec(self, x: DistributedArray) -> DistributedArray:
167-
ncp = get_module(x.engine)
168-
if x.partition != Partition.SCATTER:
169-
raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
170-
171-
y = DistributedArray(
172-
global_shape=(self.N * self.dimsd[1]),
173-
local_shapes=[(self.N * c) for c in self._rank_col_lens],
174-
mask=x.mask,
175-
partition=Partition.SCATTER,
176-
dtype=self.dtype,
177-
base_comm=self.base_comm
178-
)
179-
180-
my_own_cols = self._rank_col_lens[self.rank]
181-
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
182-
X_local = x_arr.astype(self.dtype)
183-
Y_local = ncp.vstack(
184-
self._row_comm.allgather(
185-
ncp.matmul(self.A, X_local)
186-
)
187-
)
188-
y[:] = Y_local.flatten()
189-
return y
190-
191166
@staticmethod
192167
def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
193168
r"""Configure active grid
194169
195170
Configure a square process grid from a parent MPI communicator and
196-
select the subset of "active" processes. Each process in ``base_comm``
171+
select a subset of "active" processes. Each process in ``base_comm``
197172
is assigned to a logical 2D grid of size :math:`P' \times P'`,
198173
where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first
199174
:math:`active_dim x active_dim` processes
@@ -218,7 +193,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
218193
if inactive).
219194
row : :obj:`int`
220195
Grid row index of this process in the active grid (or original rank
221-
if inactive).
196+
if inactive).
222197
col : :obj:`int`
223198
Grid column index of this process in the active grid
224199
(or original rank if inactive).
@@ -246,6 +221,31 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
246221

247222
return new_comm, new_rank, new_row, new_col, True
248223

224+
def _matvec(self, x: DistributedArray) -> DistributedArray:
225+
ncp = get_module(x.engine)
226+
if x.partition != Partition.SCATTER:
227+
raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
228+
229+
y = DistributedArray(
230+
global_shape=(self.N * self.dimsd[1]),
231+
local_shapes=[(self.N * c) for c in self._rank_col_lens],
232+
mask=x.mask,
233+
partition=Partition.SCATTER,
234+
dtype=self.dtype,
235+
base_comm=self.base_comm
236+
)
237+
238+
my_own_cols = self._rank_col_lens[self.rank]
239+
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
240+
X_local = x_arr.astype(self.dtype)
241+
Y_local = ncp.vstack(
242+
self._row_comm.allgather(
243+
ncp.matmul(self.A, X_local)
244+
)
245+
)
246+
y[:] = Y_local.flatten()
247+
return y
248+
249249
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
250250
ncp = get_module(x.engine)
251251
if x.partition != Partition.SCATTER:

0 commit comments

Comments
 (0)