@@ -340,12 +340,7 @@ def local_shapes(self):
         local_shapes : :obj:`list`
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            # gather tuple of shapes from every rank and copy from GPU to CPU
-            all_tuples = self._allgather(self.local_shape).get()
-            # NCCL returns the flat array that packs every tuple as 1-dimensional array
-            # unpack each tuple from each rank
-            tuple_len = len(self.local_shape)
-            return [tuple(all_tuples[i: i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+            return self._nccl_local_shapes(self.base_comm_nccl)
         else:
             return self._allgather(self.local_shape)
 
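Note on the two branches above: under MPI, _allgather already yields a Python list with one shape tuple per rank, whereas the NCCL allgather yields a single flat device buffer, which is why the delegated helper has to slice it back into tuples. A minimal CPU-only sketch of the difference, using assumed values for three ranks (no MPI or NCCL involved):

local_shape = (5, 10)
# what the MPI-style allgather effectively yields: one shape tuple per rank
mpi_style = [(5, 10), (5, 10), (6, 10)]
# what the NCCL-style allgather effectively yields: one flat packed buffer
nccl_style = [5, 10, 5, 10, 6, 10]
# slicing the flat buffer recovers the per-rank tuples
n = len(local_shape)
assert [tuple(nccl_style[i:i + n]) for i in range(0, len(nccl_style), n)] == mpi_style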
@@ -380,12 +375,7 @@ def asarray(self, masked: bool = False):
             return self.local_array
 
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            if masked:
-                all_tuples = self._allgather_subcomm(self.local_shape).get()
-                tuple_len = len(self.local_shape)
-                local_shapes = [tuple(all_tuples[i: i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
-            else:
-                local_shapes = self.local_shapes
+            local_shapes = self._nccl_local_shapes(self.sub_comm if masked else self.base_comm_nccl)
             return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
                                 self.local_array, local_shapes, self.axis)
         else:
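For context, the NCCL branch then passes local_shapes to nccl_asarray to rebuild the global array from the distributed pieces. A conceptual CPU-only sketch of that reconstruction, with a hypothetical helper (not the library routine) and assumed shapes for two ranks:

import numpy as np

def asarray_from_flat(flat_buf, local_shapes, axis):
    # split the flat gathered buffer into per-rank blocks and concatenate
    # them along the distribution axis
    blocks, offset = [], 0
    for shape in local_shapes:
        size = int(np.prod(shape))
        blocks.append(flat_buf[offset:offset + size].reshape(shape))
        offset += size
    return np.concatenate(blocks, axis=axis)

# two "ranks", each holding a (2, 3) block of a (4, 3) global array (axis=0)
local_shapes = [(2, 3), (2, 3)]
flat_buf = np.arange(12, dtype=float)
print(asarray_from_flat(flat_buf, local_shapes, axis=0).shape)  # (4, 3)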
@@ -560,6 +550,21 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
             self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
         return recv_buf
 
+    def _nccl_local_shapes(self, nccl_comm: NcclCommunicatorType):
+        """Get the list of shapes of every GPU in the communicator
+        """
+        # gather the tuple of shapes from every rank within the communicator and copy from GPU to CPU
+        if nccl_comm == self.sub_comm:
+            all_tuples = self._allgather_subcomm(self.local_shape).get()
+        else:
+            assert nccl_comm == self.base_comm_nccl
+            all_tuples = self._allgather(self.local_shape).get()
+        # NCCL returns a flat array that packs every tuple into one 1-dimensional array
+        # unpack each tuple from each rank
+        tuple_len = len(self.local_shape)
+        local_shapes = [tuple(all_tuples[i: i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+        return local_shapes
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
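The new helper serves both call sites above and decides which allgather wrapper to issue by comparing the passed communicator against self.sub_comm. A small self-contained sketch of that dispatch, with hypothetical stand-ins for the communicators and the allgather wrappers (a real run would issue NCCL collectives on device buffers and copy back with .get()):

class ShapesExample:
    def __init__(self):
        self.sub_comm = object()          # placeholder subcommunicator handle
        self.base_comm_nccl = object()    # placeholder global communicator handle
        self.local_shape = (5, 10)

    def _allgather_subcomm(self, shape):
        return list(shape) * 2            # pretend the subcommunicator has 2 ranks

    def _allgather(self, shape):
        return list(shape) * 3            # pretend the global communicator has 3 ranks

    def _nccl_local_shapes(self, nccl_comm):
        # dispatch on which communicator was passed, mirroring the helper above
        if nccl_comm == self.sub_comm:
            all_tuples = self._allgather_subcomm(self.local_shape)
        else:
            assert nccl_comm == self.base_comm_nccl
            all_tuples = self._allgather(self.local_shape)
        n = len(self.local_shape)
        return [tuple(all_tuples[i:i + n]) for i in range(0, len(all_tuples), n)]

x = ShapesExample()
print(x._nccl_local_shapes(x.base_comm_nccl))   # [(5, 10), (5, 10), (5, 10)]
print(x._nccl_local_shapes(x.sub_comm))         # [(5, 10), (5, 10)]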