@@ -380,7 +380,7 @@ def asarray(self, masked: bool = False):
380380 return self .local_array
381381
382382 if deps .nccl_enabled and getattr (self , "base_comm_nccl" ):
383- return nccl_asarray (self .sub_comm if masked else self .base_comm ,
383+ return nccl_asarray (self .sub_comm if masked else self .base_comm_nccl ,
384384 self .local_array , self .local_shapes , self .axis )
385385 else :
386386 # Gather all the local arrays and apply concatenation.
@@ -640,9 +640,9 @@ def dot(self, dist_array):
640640 self ._check_mask (dist_array )
641641 ncp = get_module (self .engine )
642642 # Convert to Partition.SCATTER if Partition.BROADCAST
643- x = DistributedArray .to_dist (x = self .local_array , base_comm_nccl = self .base_comm_nccl ) \
643+ x = DistributedArray .to_dist (x = self .local_array , base_comm = self . base_comm , base_comm_nccl = self .base_comm_nccl ) \
644644 if self .partition in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ] else self
645- y = DistributedArray .to_dist (x = dist_array .local_array , base_comm_nccl = self .base_comm_nccl ) \
645+ y = DistributedArray .to_dist (x = dist_array .local_array , base_comm = self . base_comm , base_comm_nccl = self .base_comm_nccl ) \
646646 if self .partition in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ] else dist_array
647647 # Flatten the local arrays and calculate dot product
648648 return self ._allreduce_subcomm (ncp .dot (x .local_array .flatten (), y .local_array .flatten ()))
@@ -716,7 +716,7 @@ def norm(self, ord: Optional[int] = None,
716716 Axis along which vector norm needs to be computed. Defaults to ``-1``
717717 """
718718 # Convert to Partition.SCATTER if Partition.BROADCAST
719- x = DistributedArray .to_dist (x = self .local_array , base_comm_nccl = self .base_comm_nccl ) \
719+ x = DistributedArray .to_dist (x = self .local_array , base_comm = self . base_comm , base_comm_nccl = self .base_comm_nccl ) \
720720 if self .partition in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ] else self
721721 if axis == - 1 :
722722 # Flatten the local arrays and calculate norm
0 commit comments