Skip to content

Commit 9c04583

Browse files
committed
feat: ensure y arrays are created on same engine as x
1 parent 3738a06 commit 9c04583

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
335335
local_shapes=[(self.N * c) for c in self._rank_col_lens],
336336
mask=x.mask,
337337
partition=Partition.SCATTER,
338+
engine=x.engine,
338339
dtype=output_dtype,
339340
base_comm=self.base_comm
340341
)
@@ -372,6 +373,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
372373
local_shapes=[self.K * c for c in self._rank_col_lens],
373374
mask=x.mask,
374375
partition=Partition.SCATTER,
376+
engine=x.engine,
375377
dtype=output_dtype,
376378
base_comm=self.base_comm
377379
)
@@ -573,6 +575,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
573575
mask=x.mask,
574576
local_shapes=local_shapes,
575577
partition=Partition.SCATTER,
578+
engine=x.engine,
576579
dtype=output_dtype,
577580
base_comm=self.base_comm)
578581

@@ -638,6 +641,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
638641
mask=x.mask,
639642
local_shapes=local_shapes,
640643
partition=Partition.SCATTER,
644+
engine=x.engine,
641645
dtype=output_dtype,
642646
base_comm=self.base_comm
643647
)

0 commit comments

Comments
 (0)