@@ -621,9 +621,9 @@ def dot(self, dist_array):
621621 self ._check_mask (dist_array )
622622 ncp = get_module (self .engine )
623623 # Convert to Partition.SCATTER if Partition.BROADCAST
624- x = DistributedArray .to_dist (x = self .local_array ) \
624+ x = DistributedArray .to_dist (x = self .local_array , base_comm_nccl = self . base_comm_nccl ) \
625625 if self .partition in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ] else self
626- y = DistributedArray .to_dist (x = dist_array .local_array ) \
626+ y = DistributedArray .to_dist (x = dist_array .local_array , base_comm_nccl = self . base_comm_nccl ) \
627627 if self .partition in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ] else dist_array
628628 # Flatten the local arrays and calculate dot product
629629 return self ._allreduce_subcomm (ncp .dot (x .local_array .flatten (), y .local_array .flatten ()))
@@ -676,6 +676,7 @@ def zeros_like(self):
676676 """
677677 arr = DistributedArray (global_shape = self .global_shape ,
678678 base_comm = self .base_comm ,
679+ base_comm_nccl = self .base_comm_nccl ,
679680 partition = self .partition ,
680681 axis = self .axis ,
681682 local_shapes = self .local_shapes ,
@@ -696,7 +697,7 @@ def norm(self, ord: Optional[int] = None,
696697 Axis along which vector norm needs to be computed. Defaults to ``-1``
697698 """
698699 # Convert to Partition.SCATTER if Partition.BROADCAST
699- x = DistributedArray .to_dist (x = self .local_array ) \
700+ x = DistributedArray .to_dist (x = self .local_array , base_comm_nccl = self . base_comm_nccl ) \
700701 if self .partition in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ] else self
701702 if axis == - 1 :
702703 # Flatten the local arrays and calculate norm
@@ -711,6 +712,7 @@ def conj(self):
711712 """
712713 conj = DistributedArray (global_shape = self .global_shape ,
713714 base_comm = self .base_comm ,
715+ base_comm_nccl = self .base_comm_nccl ,
714716 partition = self .partition ,
715717 axis = self .axis ,
716718 local_shapes = self .local_shapes ,
@@ -725,6 +727,7 @@ def copy(self):
725727 """
726728 arr = DistributedArray (global_shape = self .global_shape ,
727729 base_comm = self .base_comm ,
730+ base_comm_nccl = self .base_comm_nccl ,
728731 partition = self .partition ,
729732 axis = self .axis ,
730733 local_shapes = self .local_shapes ,
@@ -879,7 +882,8 @@ def asarray(self):
879882 Global Array gathered at all ranks
880883
881884 """
882- return np .hstack ([distarr .asarray ().ravel () for distarr in self .distarrays ])
885+ ncp = get_module (self .distarrays [0 ].engine )
886+ return ncp .hstack ([distarr .asarray ().ravel () for distarr in self .distarrays ])
883887
884888 def _check_stacked_size (self , stacked_array ):
885889 """Check that arrays have consistent size
0 commit comments