@@ -694,14 +694,25 @@ def _compute_vector_norm(self, local_array: NDArray,
694694            recv_buf  =  self ._allreduce_subcomm (ncp .count_nonzero (local_array , axis = axis ).astype (ncp .float64 ))
695695        elif  ord  ==  ncp .inf :
696696            # Calculate max followed by max reduction 
697-             recv_buf  =  self ._allreduce_subcomm (ncp .max (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 ),
698-                                                recv_buf , op = MPI .MAX )
699-             recv_buf  =  ncp .squeeze (recv_buf , axis = axis )
697+             # TODO (tharitt): CuPy + MPI currently does not work well with buffered communication, particularly
698+             # with the MAX and MIN operators. When no NCCL communicator is set, we copy the arrays to the CPU,
698+             # perform the reduction there, and copy the result back to the GPU
699+             send_buf  =  ncp .max (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 )
700+             if  self .engine == "cupy"  and  self .base_comm_nccl  is  None :
701+                 recv_buf  =  self ._allreduce_subcomm (send_buf .get (), recv_buf .get (), op = MPI .MAX )
702+                 recv_buf  =  ncp .asarray (ncp .squeeze (recv_buf , axis = axis ))
703+             else :
704+                 recv_buf  =  self ._allreduce_subcomm (send_buf , recv_buf , op = MPI .MAX )
705+                 recv_buf  =  ncp .squeeze (recv_buf , axis = axis )
700706        elif  ord  ==  - ncp .inf :
701707            # Calculate min followed by min reduction 
702-             recv_buf  =  self ._allreduce_subcomm (ncp .min (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 ),
703-                                                recv_buf , op = MPI .MIN )
704-             recv_buf  =  ncp .squeeze (recv_buf , axis = axis )
708+             # TODO (tharitt): same CPU round-trip workaround as in the infinity-norm branch above
709+             send_buf  =  ncp .min (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 )
710+             if  self .engine  ==  "cupy"  and  self .base_comm_nccl  is  None :
711+                 recv_buf  =  self ._allreduce_subcomm (send_buf .get (), recv_buf .get (), op = MPI .MIN )
712+                 recv_buf  =  ncp .asarray (ncp .squeeze (recv_buf , axis = axis ))
713+             else :
714+                 recv_buf  =  self ._allreduce_subcomm (send_buf , recv_buf , op = MPI .MIN )
715+                 recv_buf  =  ncp .asarray (ncp .squeeze (recv_buf , axis = axis ))
705716
706717        else :
707718            recv_buf  =  self ._allreduce_subcomm (ncp .sum (ncp .abs (ncp .float_power (local_array , ord )), axis = axis ))
0 commit comments