@@ -163,37 +163,12 @@ def __init__(
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
 
-    def _matvec(self, x: DistributedArray) -> DistributedArray:
-        ncp = get_module(x.engine)
-        if x.partition != Partition.SCATTER:
-            raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-
-        y = DistributedArray(
-            global_shape=(self.N * self.dimsd[1]),
-            local_shapes=[(self.N * c) for c in self._rank_col_lens],
-            mask=x.mask,
-            partition=Partition.SCATTER,
-            dtype=self.dtype,
-            base_comm=self.base_comm
-        )
-
-        my_own_cols = self._rank_col_lens[self.rank]
-        x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
-        X_local = x_arr.astype(self.dtype)
-        Y_local = ncp.vstack(
-            self._row_comm.allgather(
-                ncp.matmul(self.A, X_local)
-            )
-        )
-        y[:] = Y_local.flatten()
-        return y
-
     @staticmethod
     def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
         r"""Configure active grid
 
         Configure a square process grid from a parent MPI communicator and
-        select the subset of "active" processes. Each process in ``base_comm``
+        select a subset of "active" processes. Each process in ``base_comm``
         is assigned to a logical 2D grid of size :math:`P' \times P'`,
         where :math:`P' = \bigl\lceil \sqrt{P} \bigr\rceil`. Only the first
         :math:`active_dim x active_dim` processes
@@ -218,7 +193,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
             if inactive).
         row : :obj:`int`
             Grid row index of this process in the active grid (or original rank
-            if inactive).
+            if inactive).
         col : :obj:`int`
             Grid column index of this process in the active grid
             (or original rank if inactive).
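For readers following the docstring above, here is a minimal, self-contained sketch of the grid-selection logic it describes, assuming :math:`P' = \lceil \sqrt{P} \rceil`, row-major rank placement, and a placeholder rule for ``active_dim`` (its exact definition is not visible in this excerpt); the function name and split logic are illustrative, not the library's implementation.

# Minimal sketch (not the library code) of the grid selection described in
# the docstring above: P' = ceil(sqrt(P)), row-major placement, and a
# communicator split into active/inactive ranks. The active_dim rule below
# is only an assumption, since it is not shown in this excerpt.
import math
from mpi4py import MPI


def sketch_active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
    rank, size = base_comm.Get_rank(), base_comm.Get_size()
    p_prime = math.isqrt(size - 1) + 1          # ceil(sqrt(P))
    row, col = divmod(rank, p_prime)            # row-major 2D placement
    active_dim = min(N, M, p_prime)             # placeholder assumption
    is_active = row < active_dim and col < active_dim
    # Only active ranks share the new sub-communicator; inactive ranks keep
    # their original rank/row/col so callers can still identify them.
    new_comm = base_comm.Split(color=int(is_active), key=rank)
    new_rank = new_comm.Get_rank() if is_active else rank
    return new_comm, new_rank, row, col, is_active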
@@ -246,6 +221,31 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
 
         return new_comm, new_rank, new_row, new_col, True
 
+    def _matvec(self, x: DistributedArray) -> DistributedArray:
+        ncp = get_module(x.engine)
+        if x.partition != Partition.SCATTER:
+            raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
+
+        y = DistributedArray(
+            global_shape=(self.N * self.dimsd[1]),
+            local_shapes=[(self.N * c) for c in self._rank_col_lens],
+            mask=x.mask,
+            partition=Partition.SCATTER,
+            dtype=self.dtype,
+            base_comm=self.base_comm
+        )
+
+        my_own_cols = self._rank_col_lens[self.rank]
+        x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
+        X_local = x_arr.astype(self.dtype)
+        Y_local = ncp.vstack(
+            self._row_comm.allgather(
+                ncp.matmul(self.A, X_local)
+            )
+        )
+        y[:] = Y_local.flatten()
+        return y
+
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
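As read in the moved ``_matvec`` above, each rank multiplies its local row block of ``A`` by its local column block of the operand, the partial products are allgathered along the row communicator, and the gathered pieces are stacked to form the full column block of the result. Below is a plain-NumPy sketch of that pattern with made-up block sizes; the list comprehension stands in for what ``_row_comm.allgather`` would return, so no MPI is needed to run it.

# NumPy-only sketch of the allgather + vstack pattern used in _matvec.
import numpy as np

rng = np.random.default_rng(0)
n_row_ranks, blk_rows, k, my_own_cols = 3, 4, 6, 2   # assumed sizes for illustration

# One row block A_i per rank in the row communicator; every rank in the row
# holds the same local column block X_local of the operand.
A_blocks = [rng.standard_normal((blk_rows, k)) for _ in range(n_row_ranks)]
X_local = rng.standard_normal((k, my_own_cols))

# Emulate _row_comm.allgather(ncp.matmul(self.A, X_local)): one partial
# product per rank, then vstack to recover the full column block of A @ X.
gathered = [A_i @ X_local for A_i in A_blocks]
Y_local = np.vstack(gathered)                         # shape (n_row_ranks * blk_rows, my_own_cols)

assert np.allclose(Y_local, np.vstack(A_blocks) @ X_local)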