
Commit 37f7dbe

[Precision Depth Alignment] paddle.log_sigmoid (PaddlePaddle#75898)
* accuracy_stable_log_sigmoid
* fix test_activation_stride_op.py
1 parent: 6eb5588

2 files changed (+32 lines, -19 lines)

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 31 additions & 18 deletions
@@ -5111,32 +5111,39 @@ struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
   MPType zero = static_cast<MPType>(0.0f);
 
   // logsigmoid(x) = log(1 / (1 + exp(-x)))
-  // For numerical stability,
-  // logsigmoid(x) =
-  // - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
+  // Use the numerically stable form:
+  // log_sigmoid(x) = min(0, x) - log1p(exp(-abs(x)))
   __device__ __forceinline__ T operator()(const T arg_x) const {
     MPType x = static_cast<MPType>(arg_x);
-    MPType temp = x > zero ? zero : -x;
-    return static_cast<T>(-temp - log(exp(-temp) + exp(-x - temp)));
+    MPType min0 = (x < zero) ? x : zero;
+    MPType abs_x = abs(x);
+    return static_cast<T>(min0 - log1p_local(exp(-abs_x)));
   }
 };
 
 template <typename T>
 struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   MPType zero = static_cast<MPType>(0.0f);
+  MPType one = static_cast<MPType>(1.0f);
 
   // dx = dout * exp(-x) / (1 + exp(-x))
-  // For numerical stability:
-  // dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
+  // Use the stable backward form:
+  // grad = dout * (max_deriv - sign * (z / (1 + z)))
+  // where z = exp(-abs(x)), max_deriv = (x < 0) ? 1 : 0, sign = (x < 0) ? 1 : -1
   __device__ __forceinline__ T operator()(const T arg_dout,
                                           const T arg_x) const {
     MPType dout = static_cast<MPType>(arg_dout);
     MPType x = static_cast<MPType>(arg_x);
-    MPType temp1 = x > zero ? zero : -x;
-    MPType temp2 = exp(-x - temp1);
-    return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2)));
+
+    // Branch on the sign of x: max_deriv = 1 and sign = 1 for x < 0,
+    // else max_deriv = 0 and sign = -1.
+    const bool in_negative = (x < zero);
+    const MPType max_deriv = in_negative ? one : zero;
+    const MPType sign = in_negative ? one : -one;
+
+    MPType z = exp(-abs(x));
+    return static_cast<T>(dout * (max_deriv - sign * (z / (one + z))));
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
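As a quick sanity check (not part of this commit; the function names below are illustrative), the stable forward identity min(0, x) - log1p(exp(-abs(x))) can be compared against the naive log(1 / (1 + exp(-x))) in NumPy:

import numpy as np

def naive_log_sigmoid(x):
    # Naive form: exp(-x) overflows float64 for large negative x.
    return np.log(1.0 / (1.0 + np.exp(-x)))

def stable_log_sigmoid(x):
    # Stable form used by the new kernel above.
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

x = np.array([-1000.0, -20.0, 0.0, 20.0, 1000.0])
with np.errstate(over="ignore", divide="ignore"):
    print(naive_log_sigmoid(x))   # [-inf, -20.0, -0.6931..., -2.06e-09, 0.0]
print(stable_log_sigmoid(x))      # [-1000.0, -20.0, -0.6931..., -2.06e-09, 0.0]

The naive form returns -inf at x = -1000 because exp(1000) overflows, while the stable form degrades gracefully to min(0, x).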
@@ -5146,19 +5153,25 @@ template <typename T>
 struct CudaLogSigmoidGradFunctor<ComplexType<T>>
     : public BaseActivationFunctor<ComplexType<T>> {
   ComplexType<T> zero = static_cast<ComplexType<T>>(0.0f);
+  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
 
   // dx = dout * exp(-x) / (1 + exp(-x))
-  // For numerical stability:
-  // dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
+  // Use the stable backward form:
+  // grad = dout * (max_deriv - sign * (z / (1 + z)))
+  // where z = exp(-abs(x)), max_deriv = (x < 0) ? 1 : 0, sign = (x < 0) ? 1 : -1
   __device__ __forceinline__ ComplexType<T> operator()(
       const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
     ComplexType<T> dout = static_cast<ComplexType<T>>(arg_dout);
     ComplexType<T> x = static_cast<ComplexType<T>>(arg_x);
-    ComplexType<T> temp1 = x > zero ? zero : -x;
-    ComplexType<T> temp2 = exp(-x - temp1);
-    return static_cast<ComplexType<T>>(dout *
-                                       conj(temp2 / (exp(-temp1) + temp2)));
+
+    // Branch on the sign of x: max_deriv = 1 and sign = 1 for x < 0,
+    // else max_deriv = 0 and sign = -1.
+    const bool in_negative = (x < zero);
+    const ComplexType<T> max_deriv = in_negative ? one : zero;
+    const ComplexType<T> sign = in_negative ? one : -one;
+
+    ComplexType<T> z = exp(-abs(x));
+    return static_cast<ComplexType<T>>(
+        dout * conj(max_deriv - sign * (z / (one + z))));
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
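The gradient identity used in both gradient functors above can be checked the same way: d/dx log_sigmoid(x) = exp(-x) / (1 + exp(-x)) = 1 / (1 + exp(x)), and the branch-based form reproduces it from z = exp(-abs(x)) without overflowing on either tail. A minimal NumPy sketch (again illustrative, not part of the commit):

import numpy as np

def stable_log_sigmoid_grad(x):
    # Mirrors the CUDA functor: z = exp(-|x|);
    # max_deriv = 1 and sign = 1 for x < 0, else 0 and -1.
    z = np.exp(-np.abs(x))
    max_deriv = np.where(x < 0, 1.0, 0.0)
    sign = np.where(x < 0, 1.0, -1.0)
    return max_deriv - sign * (z / (1.0 + z))

x = np.array([-1000.0, -5.0, 0.0, 5.0, 1000.0])
print(stable_log_sigmoid_grad(x))  # [1.0, 0.9933..., 0.5, 0.0066..., 0.0]
# Agrees with the direct 1 / (1 + exp(x)) on inputs where that is safe:
print(1.0 / (1.0 + np.exp(np.array([-5.0, 0.0, 5.0]))))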

test/legacy_test/test_activation_stride_op.py

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,7 @@ def ref_sigmoid(x):
 
 
 def ref_log_sigmoid(x):
-    out = np.log(1 / (1 + np.exp(-x)))
+    out = -np.log1p(np.exp(-x))
     return out
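The reference change matters for large positive x: in float64, 1 + np.exp(-x) rounds to 1.0 once exp(-x) drops below machine epsilon, so the old expression returns exactly 0.0, while log1p keeps the small term. A quick illustration (not part of the commit):

import numpy as np

x = np.float64(40.0)
print(np.log(1.0 / (1.0 + np.exp(-x))))  # 0.0 -- exp(-40) is lost in 1 + eps
print(-np.log1p(np.exp(-x)))             # -4.248...e-18, the true magnitude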
