@@ -166,6 +166,9 @@ def _matvec(self, x):
166166 if x .partition != Partition .SCATTER :
167167 raise ValueError (f"x should have partition={ Partition .SCATTER } Got { x .partition } instead..." )
168168
169+ y = DistributedArray (global_shape = self .shape [0 ],
170+ partition = Partition .SCATTER )
171+
169172 core = x .local_array .reshape (self .local_dims )
170173 halo_arr = ncp .zeros (self .local_extent , dtype = self .dtype )
171174 # insert core
@@ -181,18 +184,18 @@ def _matvec(self, x):
181184 for ax in range (self .ndim ):
182185 self ._apply_bc_along_axis (ncp , halo_arr , axis = ax )
183186 # pack result
184- res = DistributedArray (global_shape = self .shape [0 ],
185- partition = Partition .SCATTER )
186- res [:] = halo_arr .ravel ()
187- return res
187+ y [:] = halo_arr .ravel ()
188+ return y
188189
189190 def _rmatvec (self , x ):
190191 if x .partition != Partition .SCATTER :
191192 raise ValueError (f"x should have partition={ Partition .SCATTER } Got { x .partition } instead..." )
192- res = DistributedArray (global_shape = self .shape [1 ],
193+
194+ y = DistributedArray (global_shape = self .shape [1 ],
193195 partition = Partition .SCATTER )
196+
194197 arr = x .local_array .reshape (self .local_extent )
195198 core_slices = [slice (left , left + ldim ) for left , ldim in zip (self .halo [::2 ], self .local_dims )]
196199 core = arr [tuple (core_slices )]
197- res [:] = core .ravel ()
198- return res
200+ y [:] = core .ravel ()
201+ return y
0 commit comments