@@ -694,14 +694,25 @@ def _compute_vector_norm(self, local_array: NDArray,
694694 recv_buf = self ._allreduce_subcomm (ncp .count_nonzero (local_array , axis = axis ).astype (ncp .float64 ))
695695 elif ord == ncp .inf :
696696 # Calculate max followed by max reduction
697- recv_buf = self ._allreduce_subcomm (ncp .max (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 ),
698- recv_buf , op = MPI .MAX )
699- recv_buf = ncp .squeeze (recv_buf , axis = axis )
697+ # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly
698+ # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
699+ send_buf = ncp .max (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 )
700+ if self .engine == "cupy" and self .base_comm_nccl is None :
701+ recv_buf = self ._allreduce_subcomm (send_buf .get (), recv_buf .get (), op = MPI .MAX )
702+ recv_buf = ncp .asarray (ncp .squeeze (recv_buf , axis = axis ))
703+ else :
704+ recv_buf = self ._allreduce_subcomm (send_buf , recv_buf , op = MPI .MAX )
705+ recv_buf = ncp .squeeze (recv_buf , axis = axis )
700706 elif ord == - ncp .inf :
701707 # Calculate min followed by min reduction
702- recv_buf = self ._allreduce_subcomm (ncp .min (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 ),
703- recv_buf , op = MPI .MIN )
704- recv_buf = ncp .squeeze (recv_buf , axis = axis )
708+ # TODO (tharitt): see the comment above in infinity norm
709+ send_buf = ncp .min (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 )
710+ if self .engine == "cupy" and self .base_comm_nccl is None :
711+ recv_buf = self ._allreduce_subcomm (send_buf .get (), recv_buf .get (), op = MPI .MIN )
712+ recv_buf = ncp .asarray (ncp .squeeze (recv_buf , axis = axis ))
713+ else :
714+ recv_buf = self ._allreduce_subcomm (send_buf , recv_buf , op = MPI .MIN )
715+ recv_buf = ncp .asarray (ncp .squeeze (recv_buf , axis = axis ))
705716
706717 else :
707718 recv_buf = self ._allreduce_subcomm (ncp .sum (ncp .abs (ncp .float_power (local_array , ord )), axis = axis ))
0 commit comments