@@ -485,7 +485,6 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         else:
             if is_cuda_aware_mpi or self.engine == "numpy":
                 ncp = get_module(self.engine)
-                # mpi_type = MPI._typedict[send_buf.dtype.char]
                 recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
                 self.base_comm.Allreduce(send_buf, recv_buf, op)
                 return recv_buf
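For context, the buffered-MPI branch above amounts to pre-allocating a receive buffer of matching size and dtype and calling the uppercase, buffer-based `Allreduce`. A minimal standalone sketch with plain mpi4py and NumPy (the array contents and the SUM op are illustrative):

```python
# Run with e.g. `mpiexec -n 4 python sketch.py`
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Each rank contributes a small array filled with its rank id
send_buf = np.full(5, comm.Get_rank(), dtype=np.float64)

# Pre-allocate the receive buffer, mirroring
# `recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)` in the diff
recv_buf = np.zeros(send_buf.size, dtype=send_buf.dtype)

# Buffer-based (uppercase) Allreduce; every rank gets the elementwise sum
comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)
```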
@@ -505,7 +504,6 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         else:
             if is_cuda_aware_mpi or self.engine == "numpy":
                 ncp = get_module(self.engine)
-                # mpi_type = MPI._typedict[send_buf.dtype.char]
                 recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
                 self.sub_comm.Allreduce(send_buf, recv_buf, op)
                 return recv_buf
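`_allreduce_subcomm` follows the same pattern, only over `self.sub_comm`. A sketch of the same reduction scoped to a subcommunicator; the even/odd `Split` below is purely illustrative, not how pylops-mpi builds its subcommunicators:

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Illustrative split: even ranks in one subcommunicator, odd ranks in another
sub_comm = comm.Split(color=rank % 2, key=rank)

send_buf = np.ones(3, dtype=np.float64)
recv_buf = np.zeros(send_buf.size, dtype=send_buf.dtype)

# The reduction only involves ranks sharing the same color
sub_comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)
```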
@@ -743,6 +741,9 @@ def _compute_vector_norm(self, local_array: NDArray,
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
+                # TODO (tharitt): in the current implementation there seems to be a semantic difference
+                # between buffered MPI and NCCL: a (1, size) array is collapsed to (size,) with buffered
+                # MPI, while NCCL retains it; there may be a way to unify this via how we allocate the recv_buf.
                 if self.base_comm_nccl:
                     recv_buf = ncp.squeeze(recv_buf, axis=axis)
         elif ord == -ncp.inf:
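The shape mismatch described in the TODO can be seen directly in the buffered-MPI path: because the receive buffer is allocated flat (`ncp.zeros(send_buf.size, ...)`), a `(1, size)` send comes back as `(size,)`, while, per the TODO, NCCL keeps the leading axis and needs the extra squeeze. A sketch of the MPI side (the MAX op and shapes are illustrative):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

send_buf = np.ones((1, 4), dtype=np.float64)              # shape (1, 4)
recv_buf = np.zeros(send_buf.size, dtype=send_buf.dtype)  # flat: shape (4,)

comm.Allreduce(send_buf, recv_buf, op=MPI.MAX)
print(recv_buf.shape)  # (4,) -- already collapsed, no squeeze needed

# Per the TODO, the NCCL path would instead come back as (1, 4), hence the
# NCCL-only `recv_buf = ncp.squeeze(recv_buf, axis=axis)` in the diff.
```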