@@ -141,7 +141,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
141141 @reshaped
142142 def _matvec_forward (self , x : DistributedArray ) -> DistributedArray :
143143 ncp = get_module (x .engine )
144- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
144+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
145145 axis = x .axis , engine = x .engine , dtype = self .dtype )
146146 ghosted_x = x .add_ghost_cells (cells_back = 1 )
147147 y_forward = ghosted_x [1 :] - ghosted_x [:- 1 ]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
153153 @reshaped
154154 def _rmatvec_forward (self , x : DistributedArray ) -> DistributedArray :
155155 ncp = get_module (x .engine )
156- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
156+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
157157 axis = x .axis , engine = x .engine , dtype = self .dtype )
158158 y [:] = 0
159159 if self .rank == self .size - 1 :
@@ -171,8 +171,8 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
171171 @reshaped
172172 def _matvec_backward (self , x : DistributedArray ) -> DistributedArray :
173173 ncp = get_module (x .engine )
174- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
175- axis = x .axis , engine = x .engine , dtype = self .dtype )
174+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
175+ axis = x .axis , engine = x .engine , dtype = self .dtype )
176176 ghosted_x = x .add_ghost_cells (cells_front = 1 )
177177 y_backward = ghosted_x [1 :] - ghosted_x [:- 1 ]
178178 if self .rank == 0 :
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
183183 @reshaped
184184 def _rmatvec_backward (self , x : DistributedArray ) -> DistributedArray :
185185 ncp = get_module (x .engine )
186- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
186+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
187187 axis = x .axis , engine = x .engine , dtype = self .dtype )
188188 y [:] = 0
189189 ghosted_x = x .add_ghost_cells (cells_back = 1 )
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
201201 @reshaped
202202 def _matvec_centered3 (self , x : DistributedArray ) -> DistributedArray :
203203 ncp = get_module (x .engine )
204- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
204+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
205205 axis = x .axis , engine = x .engine , dtype = self .dtype )
206206 ghosted_x = x .add_ghost_cells (cells_front = 1 , cells_back = 1 )
207207 y_centered = 0.5 * (ghosted_x [2 :] - ghosted_x [:- 2 ])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
221221 @reshaped
222222 def _rmatvec_centered3 (self , x : DistributedArray ) -> DistributedArray :
223223 ncp = get_module (x .engine )
224- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
224+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
225225 axis = x .axis , engine = x .engine , dtype = self .dtype )
226226 y [:] = 0
227227
@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
249249 @reshaped
250250 def _matvec_centered5 (self , x : DistributedArray ) -> DistributedArray :
251251 ncp = get_module (x .engine )
252- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
252+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
253253 axis = x .axis , engine = x .engine , dtype = self .dtype )
254254 ghosted_x = x .add_ghost_cells (cells_front = 2 , cells_back = 2 )
255255 y_centered = (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
276276 @reshaped
277277 def _rmatvec_centered5 (self , x : DistributedArray ) -> DistributedArray :
278278 ncp = get_module (x .engine )
279- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
279+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
280280 axis = x .axis , engine = x .engine , dtype = self .dtype )
281281 y [:] = 0
282282 ghosted_x = x .add_ghost_cells (cells_back = 4 )
0 commit comments