@@ -554,7 +554,7 @@ struct CosFunctor : public BaseActivationFunctor<T> {
554554template <typename T>
555555struct LogitFunctor {
556556 template <typename Device, typename X, typename Out, typename P>
557- void operator ()(Device d, X x, Out out, P p, float eps) const {
557+ void operator ()(Device d, X x, Out out, P p, double eps) const {
558558 // logit(x) = ln(x/(1-x))
559559 auto tmp_x =
560560 (x.cwiseMin (static_cast <T>(1.0 - eps))).cwiseMax (static_cast <T>(eps));
@@ -1268,7 +1268,7 @@ struct AtanGradFunctor<ComplexType<T>>
12681268template <typename T>
12691269struct LogitGradFunctor {
12701270 template <typename Device, typename X, typename dOut, typename dX, typename P>
1271- void operator ()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
1271+ void operator ()(Device d, X x, dOut dout, dX dx, P p, double eps) const {
12721272 // logit(x)' = 1/(x*(1-x))
12731273 if (!eps) {
12741274 dx.device (d) = (x < static_cast <T>(0.0 ) || x > static_cast <T>(1.0 ))
@@ -3422,15 +3422,14 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
34223422
34233423template <typename T>
34243424struct CudaLogitFunctor : public BaseActivationFunctor <T> {
3425+ using AttrPair = std::vector<std::pair<const char *, double *>>;
34253426 using MT = typename phi::dtype::MPTypeTrait<T>::Type;
34263427
34273428 MT zero = static_cast <MT>(0 .0f );
34283429 MT one = static_cast <MT>(1 .0f );
3429- float eps;
3430+ double eps;
34303431
3431- typename BaseActivationFunctor<T>::AttrPair GetAttrs () {
3432- return {{" eps" , &eps}};
3433- }
3432+ typename CudaLogitFunctor<T>::AttrPair GetAttrs () { return {{" eps" , &eps}}; }
34343433
34353434 // logit(x) = ln(x/(1-x))
34363435 __device__ __forceinline__ T operator ()(const T arg_x) const {
@@ -3449,13 +3448,14 @@ struct CudaLogitFunctor : public BaseActivationFunctor<T> {
34493448
34503449template <typename T>
34513450struct CudaLogitGradFunctor : public BaseActivationFunctor <T> {
3451+ using AttrPair = std::vector<std::pair<const char *, double *>>;
34523452 using MT = typename phi::dtype::MPTypeTrait<T>::Type;
34533453
3454- float eps;
3454+ double eps;
34553455 MT zero = static_cast <MT>(0 .0f );
34563456 MT one = static_cast <MT>(1 .0f );
34573457
3458- typename BaseActivationFunctor <T>::AttrPair GetAttrs () {
3458+ typename CudaLogitGradFunctor <T>::AttrPair GetAttrs () {
34593459 return {{" eps" , &eps}};
34603460 }
34613461 // logit(x)' = 1/(x*(1-x))
0 commit comments