@@ -340,12 +340,7 @@ def local_shapes(self):
340340 local_shapes : :obj:`list`
341341 """
342342 if deps .nccl_enabled and getattr (self , "base_comm_nccl" ):
343- # gather tuple of shapes from every rank and copy from GPU to CPU
344- all_tuples = self ._allgather (self .local_shape ).get ()
345- # NCCL returns the flat array that packs every tuple as 1-dimensional array
346- # unpack each tuple from each rank
347- tuple_len = len (self .local_shape )
348- return [tuple (all_tuples [i : i + tuple_len ]) for i in range (0 , len (all_tuples ), tuple_len )]
343+ return self ._nccl_local_shapes (False )
349344 else :
350345 return self ._allgather (self .local_shape )
351346
@@ -380,8 +375,8 @@ def asarray(self, masked: bool = False):
380375 return self .local_array
381376
382377 if deps .nccl_enabled and getattr (self , "base_comm_nccl" ):
383- return nccl_asarray (self .sub_comm if masked else self .base_comm ,
384- self .local_array , self .local_shapes , self .axis )
378+ return nccl_asarray (self .sub_comm if masked else self .base_comm_nccl ,
379+ self .local_array , self ._nccl_local_shapes ( masked ) , self .axis )
385380 else :
386381 # Gather all the local arrays and apply concatenation.
387382 if masked :
@@ -554,6 +549,20 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
554549 self .base_comm .Recv (buf = recv_buf , source = source , tag = tag )
555550 return recv_buf
556551
552+ def _nccl_local_shapes (self , masked : bool ):
553+ """Get the list of shapes of every GPU in the communicator
554+ """
555+ # gather tuple of shapes from every rank within the communicator and copy from GPU to CPU
556+ if masked :
557+ all_tuples = self ._allgather_subcomm (self .local_shape ).get ()
558+ else :
559+ all_tuples = self ._allgather (self .local_shape ).get ()
560+ # NCCL returns the flat array that packs every tuple as a 1-dimensional array
561+ # unpack each tuple from each rank
562+ tuple_len = len (self .local_shape )
563+ local_shapes = [tuple (all_tuples [i : i + tuple_len ]) for i in range (0 , len (all_tuples ), tuple_len )]
564+ return local_shapes
565+
557566 def __neg__ (self ):
558567 arr = DistributedArray (global_shape = self .global_shape ,
559568 base_comm = self .base_comm ,
0 commit comments