@@ -646,9 +646,9 @@ def dot(self, dist_array):
         self._check_mask(dist_array)
         ncp = get_module(self.engine)
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
-        y = DistributedArray.to_dist(x=dist_array.local_array) \
+        y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
         return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
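This change matters because `to_dist` otherwise falls back to its default communicator, so a broadcast array rebuilt inside `dot` (and `norm` below) would silently drop a caller-supplied NCCL communicator and run the reduction over plain MPI. A minimal sketch of the intended argument flow, assuming mpi4py and a caller-created `nccl_comm` (a hypothetical name, not defined in this diff):

    from mpi4py import MPI
    # `nccl_comm` stands in for an NCCL communicator created by the caller;
    # the keyword names mirror the ones added in this diff.
    x = DistributedArray.to_dist(x=local_array,
                                 base_comm=MPI.COMM_WORLD,
                                 base_comm_nccl=nccl_comm)
    # dot()/norm() now rebuild broadcast arrays with these same kwargs, so the
    # subsequent _allreduce_subcomm call keeps using nccl_comm.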
@@ -701,6 +701,7 @@ def zeros_like(self):
701701 """
702702 arr = DistributedArray (global_shape = self .global_shape ,
703703 base_comm = self .base_comm ,
704+ base_comm_nccl = self .base_comm_nccl ,
704705 partition = self .partition ,
705706 axis = self .axis ,
706707 local_shapes = self .local_shapes ,
@@ -721,7 +722,7 @@ def norm(self, ord: Optional[int] = None,
             Axis along which vector norm needs to be computed. Defaults to ``-1``
         """
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         if axis == -1:
             # Flatten the local arrays and calculate norm
@@ -736,6 +737,7 @@ def conj(self):
736737 """
737738 conj = DistributedArray (global_shape = self .global_shape ,
738739 base_comm = self .base_comm ,
740+ base_comm_nccl = self .base_comm_nccl ,
739741 partition = self .partition ,
740742 axis = self .axis ,
741743 local_shapes = self .local_shapes ,
@@ -750,6 +752,7 @@ def copy(self):
750752 """
751753 arr = DistributedArray (global_shape = self .global_shape ,
752754 base_comm = self .base_comm ,
755+ base_comm_nccl = self .base_comm_nccl ,
753756 partition = self .partition ,
754757 axis = self .axis ,
755758 local_shapes = self .local_shapes ,
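`zeros_like`, `conj`, and `copy` all construct a fresh `DistributedArray` from `self`, so each call site receives the same one-line fix: without it, the derived array's `base_comm_nccl` would revert to the constructor default and the NCCL path would be lost after any of these operations. Illustratively (attribute name taken from this diff; the before/after behaviour is an assumption based on it):

    y = x.copy()    # before this patch: y reverts to the default NCCL comm (none)
    z = x.conj()    # after: both inherit x.base_comm_nccl
    assert y.base_comm_nccl is x.base_comm_nccl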
@@ -905,7 +908,8 @@ def asarray(self):
             Global Array gathered at all ranks

         """
-        return np.hstack([distarr.asarray().ravel() for distarr in self.distarrays])
+        ncp = get_module(self.distarrays[0].engine)
+        return ncp.hstack([distarr.asarray().ravel() for distarr in self.distarrays])

     def _check_stacked_size(self, stacked_array):
         """Check that arrays have consistent size