diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 908da9e042c969..684dc1e06694eb 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -4206,7 +4206,7 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor { // dx = dout * 0.5 / out __device__ __forceinline__ T operator()(const T dout, const T out) const { - return one_half * dout / out; + return out != T(0) ? one_half * dout / out : T(0); } static constexpr ActBwdOpFwdDeps FwdDeps() {