22import numpy as np
33from mpi4py import MPI
44
5- from pylops .utils import DTypeLike
6- from pylops .utils .typing import InputDimsLike
5+ from pylops .utils . backend import get_module
6+ from pylops .utils .typing import DTypeLike , InputDimsLike
77from pylops .utils ._internal import _value_or_sized_to_tuple
88
99from pylops_mpi import (
@@ -140,17 +140,21 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
140140
141141 @reshaped
142142 def _matvec_forward (self , x : DistributedArray ) -> DistributedArray :
143- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
143+ ncp = get_module (x .engine )
144+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
145+ axis = x .axis , engine = x .engine , dtype = self .dtype )
144146 ghosted_x = x .add_ghost_cells (cells_back = 1 )
145147 y_forward = ghosted_x [1 :] - ghosted_x [:- 1 ]
146148 if self .rank == self .size - 1 :
147- y_forward = np .append (y_forward , np .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
149+ y_forward = ncp .append (y_forward , ncp .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
148150 y [:] = y_forward / self .sampling
149151 return y
150152
151153 @reshaped
152154 def _rmatvec_forward (self , x : DistributedArray ) -> DistributedArray :
153- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
155+ ncp = get_module (x .engine )
156+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
157+ axis = x .axis , engine = x .engine , dtype = self .dtype )
154158 y [:] = 0
155159 if self .rank == self .size - 1 :
156160 y [:- 1 ] -= x [:- 1 ]
@@ -159,29 +163,33 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
159163 ghosted_x = x .add_ghost_cells (cells_front = 1 )
160164 y_forward = ghosted_x [:- 1 ]
161165 if self .rank == 0 :
162- y_forward = np . insert ( y_forward , 0 , np .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
166+ y_forward = ncp . append ( ncp .zeros ((1 ,) + self .dims [1 :]), y_forward , axis = 0 )
163167 y [:] += y_forward
164168 y [:] /= self .sampling
165169 return y
166170
167171 @reshaped
168172 def _matvec_backward (self , x : DistributedArray ) -> DistributedArray :
169- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
173+ ncp = get_module (x .engine )
174+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
175+ axis = x .axis , engine = x .engine , dtype = self .dtype )
170176 ghosted_x = x .add_ghost_cells (cells_front = 1 )
171177 y_backward = ghosted_x [1 :] - ghosted_x [:- 1 ]
172178 if self .rank == 0 :
173- y_backward = np . insert ( y_backward , 0 , np .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
179+ y_backward = ncp . append ( ncp .zeros ((1 ,) + self .dims [1 :]), y_backward , axis = 0 )
174180 y [:] = y_backward / self .sampling
175181 return y
176182
177183 @reshaped
178184 def _rmatvec_backward (self , x : DistributedArray ) -> DistributedArray :
179- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
185+ ncp = get_module (x .engine )
186+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
187+ axis = x .axis , engine = x .engine , dtype = self .dtype )
180188 y [:] = 0
181189 ghosted_x = x .add_ghost_cells (cells_back = 1 )
182190 y_backward = ghosted_x [1 :]
183191 if self .rank == self .size - 1 :
184- y_backward = np .append (y_backward , np .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
192+ y_backward = ncp .append (y_backward , ncp .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
185193 y [:] -= y_backward
186194 if self .rank == 0 :
187195 y [1 :] += x [1 :]
@@ -192,13 +200,15 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
192200
193201 @reshaped
194202 def _matvec_centered3 (self , x : DistributedArray ) -> DistributedArray :
195- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
203+ ncp = get_module (x .engine )
204+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
205+ axis = x .axis , engine = x .engine , dtype = self .dtype )
196206 ghosted_x = x .add_ghost_cells (cells_front = 1 , cells_back = 1 )
197207 y_centered = 0.5 * (ghosted_x [2 :] - ghosted_x [:- 2 ])
198208 if self .rank == 0 :
199- y_centered = np . insert ( y_centered , 0 , np .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
209+ y_centered = ncp . append ( ncp .zeros ((1 ,) + self .dims [1 :]), y_centered , axis = 0 )
200210 if self .rank == self .size - 1 :
201- y_centered = np .append (y_centered , np .zeros ((min (y .global_shape [0 ] - 1 , 1 ), ) + self .dims [1 :]), axis = 0 )
211+ y_centered = ncp .append (y_centered , ncp .zeros ((min (y .global_shape [0 ] - 1 , 1 ), ) + self .dims [1 :]), axis = 0 )
202212 y [:] = y_centered
203213 if self .edge :
204214 if self .rank == 0 :
@@ -210,18 +220,21 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
210220
211221 @reshaped
212222 def _rmatvec_centered3 (self , x : DistributedArray ) -> DistributedArray :
213- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
223+ ncp = get_module (x .engine )
224+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
225+ axis = x .axis , engine = x .engine , dtype = self .dtype )
214226 y [:] = 0
227+
215228 ghosted_x = x .add_ghost_cells (cells_back = 2 )
216229 y_centered = 0.5 * ghosted_x [1 :- 1 ]
217230 if self .rank == self .size - 1 :
218- y_centered = np .append (y_centered , np .zeros ((min (y .global_shape [0 ], 2 ),) + self .dims [1 :]), axis = 0 )
231+ y_centered = ncp .append (y_centered , ncp .zeros ((min (y .global_shape [0 ], 2 ),) + self .dims [1 :]), axis = 0 )
219232 y [:] -= y_centered
220233
221234 ghosted_x = x .add_ghost_cells (cells_front = 2 )
222235 y_centered = 0.5 * ghosted_x [1 :- 1 ]
223236 if self .rank == 0 :
224- y_centered = np . insert ( y_centered , 0 , np .zeros ((min (y .global_shape [0 ], 2 ),) + self .dims [1 :]), axis = 0 )
237+ y_centered = ncp . append ( ncp .zeros ((min (y .global_shape [0 ], 2 ),) + self .dims [1 :]), y_centered , axis = 0 )
225238 y [:] += y_centered
226239 if self .edge :
227240 if self .rank == 0 :
@@ -235,7 +248,9 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
235248
236249 @reshaped
237250 def _matvec_centered5 (self , x : DistributedArray ) -> DistributedArray :
238- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
251+ ncp = get_module (x .engine )
252+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
253+ axis = x .axis , engine = x .engine , dtype = self .dtype )
239254 ghosted_x = x .add_ghost_cells (cells_front = 2 , cells_back = 2 )
240255 y_centered = (
241256 ghosted_x [:- 4 ] / 12.0
@@ -244,9 +259,9 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
244259 - ghosted_x [4 :] / 12.0
245260 )
246261 if self .rank == 0 :
247- y_centered = np . insert ( y_centered , 0 , np .zeros ((min (y .global_shape [0 ], 2 ),) + self .dims [1 :]), axis = 0 )
262+ y_centered = ncp . append ( ncp .zeros ((min (y .global_shape [0 ], 2 ),) + self .dims [1 :]), y_centered , axis = 0 )
248263 if self .rank == self .size - 1 :
249- y_centered = np .append (y_centered , np .zeros ((min (y .global_shape [0 ] - 2 , 2 ),) + self .dims [1 :]), axis = 0 )
264+ y_centered = ncp .append (y_centered , ncp .zeros ((min (y .global_shape [0 ] - 2 , 2 ),) + self .dims [1 :]), axis = 0 )
250265 y [:] = y_centered
251266 if self .edge :
252267 if self .rank == 0 :
@@ -260,34 +275,36 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
260275
261276 @reshaped
262277 def _rmatvec_centered5 (self , x : DistributedArray ) -> DistributedArray :
263- y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes , axis = x .axis , dtype = self .dtype )
278+ ncp = get_module (x .engine )
279+ y = DistributedArray (global_shape = x .global_shape , local_shapes = x .local_shapes ,
280+ axis = x .axis , engine = x .engine , dtype = self .dtype )
264281 y [:] = 0
265282 ghosted_x = x .add_ghost_cells (cells_back = 4 )
266283 y_centered = ghosted_x [2 :- 2 ] / 12.0
267284 if self .rank == self .size - 1 :
268- y_centered = np .append (y_centered , np .zeros ((min (y .global_shape [0 ], 4 ),) + self .dims [1 :]), axis = 0 )
285+ y_centered = ncp .append (y_centered , ncp .zeros ((min (y .global_shape [0 ], 4 ),) + self .dims [1 :]), axis = 0 )
269286 y [:] += y_centered
270287
271288 ghosted_x = x .add_ghost_cells (cells_front = 1 , cells_back = 3 )
272289 y_centered = 2.0 * ghosted_x [2 :- 2 ] / 3.0
273290 if self .rank == 0 :
274- y_centered = np . insert ( y_centered , 0 , np .zeros ((1 ,) + self .dims [1 :]), axis = 0 )
291+ y_centered = ncp . append ( ncp .zeros ((1 ,) + self .dims [1 :]), y_centered , axis = 0 )
275292 if self .rank == self .size - 1 :
276- y_centered = np .append (y_centered , np .zeros ((min (y .global_shape [0 ] - 1 , 3 ),) + self .dims [1 :]), axis = 0 )
293+ y_centered = ncp .append (y_centered , ncp .zeros ((min (y .global_shape [0 ] - 1 , 3 ),) + self .dims [1 :]), axis = 0 )
277294 y [:] -= y_centered
278295
279296 ghosted_x = x .add_ghost_cells (cells_front = 3 , cells_back = 1 )
280297 y_centered = 2.0 * ghosted_x [2 :- 2 ] / 3.0
281298 if self .rank == 0 :
282- y_centered = np . insert ( y_centered , 0 , np .zeros ((min (y .global_shape [0 ], 3 ),) + self .dims [1 :]), axis = 0 )
299+ y_centered = ncp . append ( ncp .zeros ((min (y .global_shape [0 ], 3 ),) + self .dims [1 :]), y_centered , axis = 0 )
283300 if self .rank == self .size - 1 :
284- y_centered = np .append (y_centered , np .zeros ((min (y .global_shape [0 ] - 3 , 1 ),) + self .dims [1 :]), axis = 0 )
301+ y_centered = ncp .append (y_centered , ncp .zeros ((min (y .global_shape [0 ] - 3 , 1 ),) + self .dims [1 :]), axis = 0 )
285302 y [:] += y_centered
286303
287304 ghosted_x = x .add_ghost_cells (cells_front = 4 )
288305 y_centered = ghosted_x [2 :- 2 ] / 12.0
289306 if self .rank == 0 :
290- y_centered = np . insert ( y_centered , 0 , np .zeros ((min (y .global_shape [0 ], 4 ),) + self .dims [1 :]), axis = 0 )
307+ y_centered = ncp . append ( ncp .zeros ((min (y .global_shape [0 ], 4 ),) + self .dims [1 :]), y_centered , axis = 0 )
291308 y [:] -= y_centered
292309 if self .edge :
293310 if self .rank == 0 :
0 commit comments