Skip to content

Commit f2c9e4c

Browse files
committed
feature: enable cupy in to_dist
1 parent c8098e8 commit f2c9e4c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from enum import Enum
66

77
from pylops.utils import DTypeLike, NDArray
8-
from pylops.utils.backend import get_module
8+
from pylops.utils.backend import get_module, get_array_module, get_module_name
99

1010

1111
class Partition(Enum):
@@ -294,6 +294,7 @@ def to_dist(cls, x: NDArray,
294294
partition=partition,
295295
axis=axis,
296296
local_shapes=local_shapes,
297+
engine=get_module_name(get_array_module(x)),
297298
dtype=x.dtype)
298299
if partition == Partition.BROADCAST:
299300
dist_array[:] = x

0 commit comments

Comments
 (0)