
Commit 6d6f9dc

adding test for linearop, stackedarray, and stackedop in nccl

1 parent fb14c86

File tree

4 files changed: 820 additions, 4 deletions


pylops_mpi/DistributedArray.py

Lines changed: 8 additions & 4 deletions
@@ -621,9 +621,9 @@ def dot(self, dist_array):
         self._check_mask(dist_array)
         ncp = get_module(self.engine)
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
-        y = DistributedArray.to_dist(x=dist_array.local_array) \
+        y = DistributedArray.to_dist(x=dist_array.local_array, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
         return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
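
Why this hunk matters: dot first re-scatters broadcast arrays through to_dist, and before this change the temporaries were built without the NCCL communicator, silently falling back to MPI. A minimal sketch of the call path, assuming CuPy-backed arrays and an already-initialized NCCL communicator nccl_comm (its setup is omitted and is not part of this diff):

import cupy as cp
from pylops_mpi import DistributedArray, Partition

# nccl_comm is assumed to exist already (construction omitted here)
x = DistributedArray(global_shape=(1000,), base_comm_nccl=nccl_comm,
                     partition=Partition.BROADCAST, engine="cupy")
y = DistributedArray(global_shape=(1000,), base_comm_nccl=nccl_comm,
                     partition=Partition.BROADCAST, engine="cupy")
x[:] = cp.ones(x.local_shape)
y[:] = cp.ones(y.local_shape)
# dot() calls to_dist() internally; with this fix the temporaries inherit
# base_comm_nccl, so the allreduce stays on NCCL instead of reverting to MPI
xy = x.dot(y)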
@@ -676,6 +676,7 @@ def zeros_like(self):
         """
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
@@ -696,7 +697,7 @@ def norm(self, ord: Optional[int] = None,
             Axis along which vector norm needs to be computed. Defaults to ``-1``
         """
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         if axis == -1:
             # Flatten the local arrays and calculate norm
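
norm gets the same one-line fix as dot: the broadcast-to-scatter conversion now carries base_comm_nccl, so the reduction over local norms also runs on NCCL. Continuing the sketch above, under the same assumptions:

n2 = x.norm(ord=2)  # to_dist() inherits base_comm_nccl; the norm
                    # reduction happens over the NCCL communicator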
@@ -711,6 +712,7 @@ def conj(self):
         """
         conj = DistributedArray(global_shape=self.global_shape,
                                 base_comm=self.base_comm,
+                                base_comm_nccl=self.base_comm_nccl,
                                 partition=self.partition,
                                 axis=self.axis,
                                 local_shapes=self.local_shapes,
@@ -725,6 +727,7 @@ def copy(self):
         """
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
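
zeros_like, conj, and copy share one pattern: any method that derives a new DistributedArray from self must forward every piece of distribution metadata, including the NCCL communicator, or the derived array silently reverts to MPI. The common shape of the three hunks above, as a sketch (the helper name _like_self is illustrative, not from the source):

def _like_self(self):
    # mirror all distribution metadata; dropping base_comm_nccl here
    # is exactly the bug these hunks fix
    return DistributedArray(global_shape=self.global_shape,
                            base_comm=self.base_comm,
                            base_comm_nccl=self.base_comm_nccl,
                            partition=self.partition,
                            axis=self.axis,
                            local_shapes=self.local_shapes)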
@@ -879,7 +882,8 @@ def asarray(self):
         Global Array gathered at all ranks

         """
-        return np.hstack([distarr.asarray().ravel() for distarr in self.distarrays])
+        ncp = get_module(self.distarrays[0].engine)
+        return ncp.hstack([distarr.asarray().ravel() for distarr in self.distarrays])

     def _check_stacked_size(self, stacked_array):
         """Check that arrays have consistent size
