@@ -485,7 +485,6 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         else:
             if is_cuda_aware_mpi or self.engine == "numpy":
                 ncp = get_module(self.engine)
-                # mpi_type = MPI._typedict[send_buf.dtype.char]
                 recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
                 self.base_comm.Allreduce(send_buf, recv_buf, op)
                 return recv_buf
@@ -505,7 +504,6 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         else:
             if is_cuda_aware_mpi or self.engine == "numpy":
                 ncp = get_module(self.engine)
-                # mpi_type = MPI._typedict[send_buf.dtype.char]
                 recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
                 self.sub_comm.Allreduce(send_buf, recv_buf, op)
                 return recv_buf
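
The line removed from both _allreduce and _allreduce_subcomm was a leftover MPI datatype lookup. A minimal sketch of why it is redundant, assuming mpi4py and numpy; the explicit form at the end is for illustration only:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
send_buf = np.ones(4, dtype=np.float32)
recv_buf = np.zeros_like(send_buf)

# Implicit datatype: mpi4py infers MPI.FLOAT from the buffer's dtype,
# which is what the code above relies on.
comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)

# Equivalent explicit form using the lookup from the removed comment;
# it is only needed when overriding the inferred datatype.
mpi_type = MPI._typedict[send_buf.dtype.char]
comm.Allreduce([send_buf, mpi_type], [recv_buf, mpi_type], op=MPI.SUM)
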
@@ -743,6 +741,9 @@ def _compute_vector_norm(self, local_array: NDArray,
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
+                # TODO (tharitt): in the current implementation there seems to be a semantic difference
+                # between buffered MPI and NCCL: buffered MPI collapses the (1, size) result to (size, ),
+                # while NCCL retains it. Unifying the two may come down to how recv_buf is allocated.
                 if self.base_comm_nccl:
                     recv_buf = ncp.squeeze(recv_buf, axis=axis)
         elif ord == -ncp.inf:
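
A minimal sketch of the shape behaviour the TODO describes, assuming mpi4py and numpy; the buffer names are illustrative, not from the codebase (run with e.g. mpirun -n 2 python sketch.py):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
send_buf = np.arange(6, dtype=np.float64).reshape(1, 6)  # (1, size), as in the norm path

# Buffered MPI path above: recv_buf is allocated flat from send_buf.size,
# so the (1, size) input comes back as (size,) after the reduction.
recv_flat = np.zeros(send_buf.size, dtype=send_buf.dtype)
comm.Allreduce(send_buf, recv_flat, op=MPI.MAX)
print(recv_flat.shape)  # (6,)

# Allocating recv_buf with send_buf's own shape (as the NCCL path effectively
# does) keeps (1, size), hence the extra ncp.squeeze(recv_buf, axis=axis).
recv_shaped = np.zeros_like(send_buf)
comm.Allreduce(send_buf, recv_shaped, op=MPI.MAX)
print(recv_shaped.shape)  # (1, 6)
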