@@ -5111,32 +5111,39 @@ struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
51115111 MPType zero = static_cast <MPType>(0 .0f );
51125112
// Forward pass: logsigmoid(x) = log(1 / (1 + exp(-x))).
//
// Evaluated as min(x, 0) - log1p(exp(-|x|)). The argument handed to exp is
// never positive, and log1p is applied to the small exp term instead of
// calling log on 1 + term.
//
// arg_x : input element in storage type T.
// Returns the result cast back to T; all arithmetic is done in MPType.
__device__ __forceinline__ T operator()(const T arg_x) const {
  const MPType v = static_cast<MPType>(arg_x);
  const MPType magnitude = abs(v);
  // min(v, 0): contributes v on the negative side and nothing otherwise.
  const MPType clamped = (v < zero) ? v : zero;
  return static_cast<T>(clamped - log1p_local(exp(-magnitude)));
}
51225122};
51235123
51245124template <typename T>
51255125struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor <T> {
51265126 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
51275127 MPType zero = static_cast <MPType>(0 .0f );
5128+ MPType one = static_cast <MPType>(1 .0f );
51285129
// Backward pass: dx = dout * exp(-x) / (1 + exp(-x)).
//
// Evaluated through z = exp(-|x|), so exp never receives a positive argument:
//   dx = dout * (offset - direction * z / (1 + z))
//   x <  0 : offset = 1, direction =  1   ->  dout * (1 - z / (1 + z))
//   x >= 0 : offset = 0, direction = -1   ->  dout * (z / (1 + z))
//
// arg_dout : upstream gradient element.
// arg_x    : forward input element.
// Returns the input gradient cast back to T; arithmetic is done in MPType.
__device__ __forceinline__ T operator()(const T arg_dout,
                                        const T arg_x) const {
  const MPType grad_out = static_cast<MPType>(arg_dout);
  const MPType v = static_cast<MPType>(arg_x);

  // Shared term z / (1 + z) with z = exp(-|x|).
  const MPType decay = exp(-abs(v));
  const MPType ratio = decay / (one + decay);

  // Branch-dependent constants selected by the sign of the input.
  const bool is_neg = (v < zero);
  const MPType offset = is_neg ? one : zero;
  const MPType direction = is_neg ? one : -one;

  return static_cast<T>(grad_out * (offset - direction * ratio));
}
51415148
// Returns the compile-time constant ActBwdOpFwdDeps::kDepX.
// NOTE(review): presumably this tells the framework that the backward functor
// needs the forward input X (operator() above takes arg_x) — confirm against
// the ActBwdOpFwdDeps definition.
51425149 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
@@ -5146,19 +5153,25 @@ template <typename T>
51465153struct CudaLogSigmoidGradFunctor <ComplexType<T>>
51475154 : public BaseActivationFunctor<ComplexType<T>> {
51485155 ComplexType<T> zero = static_cast <ComplexType<T>>(0 .0f );
5156+ ComplexType<T> one = static_cast <ComplexType<T>>(1 .0f );
51495157
// Backward pass for complex inputs:
//   dx = dout * conj(exp(-x) / (1 + exp(-x)))
//
// The derivative is evaluated in one of two algebraically identical forms:
//   x <  0 : 1 / (1 + exp(x))
//   x >= 0 : exp(-x) / (1 + exp(-x))
// Both forms are exact for every complex x, so the branch choice only decides
// which exponential is evaluated; it never changes the mathematical result.
// On the real axis this matches the real-valued functor's
// max_deriv - sign * z / (1 + z) with z = exp(-|x|).
//
// The modulus abs(x) must not be used in place of x here: exp(-|x|) discards
// the phase of x and yields a wrong gradient whenever x has a non-zero
// imaginary part.
//
// NOTE(review): the ordering that operator< imposes on ComplexType<T> is not
// visible from this block — confirm it compares real parts, which is what
// makes the selected exponential the non-growing one.
//
// arg_dout : upstream gradient element.
// arg_x    : forward input element.
// Returns the full complex input gradient.
__device__ __forceinline__ ComplexType<T> operator()(
    const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
  const ComplexType<T> dout = arg_dout;
  const ComplexType<T> x = arg_x;

  const bool in_negative = (x < zero);

  // Complex exponential of x itself (not of its modulus).
  const ComplexType<T> z = in_negative ? exp(x) : exp(-x);
  const ComplexType<T> deriv = in_negative ? one / (one + z) : z / (one + z);

  // Cast to the complex return type; casting to the scalar T would narrow the
  // gradient to a real value and lose its imaginary part.
  return static_cast<ComplexType<T>>(dout * conj(deriv));
}
51635176
// Returns the compile-time constant ActBwdOpFwdDeps::kDepX.
// NOTE(review): presumably this tells the framework that the backward functor
// needs the forward input X (operator() above takes arg_x) — confirm against
// the ActBwdOpFwdDeps definition.
51645177 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
0 commit comments