@@ -340,7 +340,7 @@ def local_shapes(self):
         local_shapes : :obj:`list`
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return self._nccl_local_shapes(self.base_comm_nccl)
+            return self._nccl_local_shapes(False)
         else:
             return self._allgather(self.local_shape)

@@ -375,9 +375,8 @@ def asarray(self, masked: bool = False):
             return self.local_array

         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            local_shapes = self._nccl_local_shapes(self.sub_comm if masked else self.base_comm_nccl)
             return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
-                                self.local_array, local_shapes, self.axis)
+                                self.local_array, self._nccl_local_shapes(masked), self.axis)
         else:
             # Gather all the local arrays and apply concatenation.
             if masked:
@@ -550,14 +549,13 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
             self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
         return recv_buf

-    def _nccl_local_shapes(self, nccl_comm: NcclCommunicatorType):
+    def _nccl_local_shapes(self, masked: bool):
         """Get the list of shapes of every GPU in the communicator
         """
         # gather tuple of shapes from every rank within the communicator and copy from GPU to CPU
-        if nccl_comm == self.sub_comm:
+        if masked:
             all_tuples = self._allgather_subcomm(self.local_shape).get()
         else:
-            assert(nccl_comm == self.base_comm_nccl)
             all_tuples = self._allgather(self.local_shape).get()
         # NCCL returns the flat array that packs every tuple as 1-dimensional array
         # unpack each tuple from each rank
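
For reference, the "unpack each tuple from each rank" step amounts to slicing one packed 1-D buffer into per-rank shape tuples. A minimal standalone sketch of that idea follows; the helper name unpack_local_shapes and the sample values are illustrative only and are not part of the patch:

import numpy as np

def unpack_local_shapes(flat, ndim):
    # split the packed buffer into consecutive chunks of length `ndim`,
    # one chunk per rank, and turn each chunk into a shape tuple
    return [tuple(int(v) for v in flat[i:i + ndim])
            for i in range(0, len(flat), ndim)]

# e.g. three ranks, each holding a (rows, cols) block of a 2-D array
packed = np.array([4, 10, 3, 10, 3, 10])
print(unpack_local_shapes(packed, ndim=2))   # [(4, 10), (3, 10), (3, 10)]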