@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 
+#include "paddle/fluid/operators/math/blas.h"
 namespace paddle {
 namespace operators {
 namespace math {
@@ -32,8 +33,8 @@ struct ValueClip {
   }
 };
 
-template <typename DeviceContext, typename T, bool is_test>
-void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
+template <typename DeviceContext, typename T, bool is_test, typename Enable>
+void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
     const DeviceContext& context, const framework::Tensor* X,
     framework::Tensor* Y) {
   auto logits = EigenMatrix<T>::From(*X);
@@ -65,36 +66,46 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
                                                  .broadcast(one_by_class));
 }
 
-template <typename DeviceContext, typename T>
-class SoftmaxFunctor<DeviceContext, T, true> {
+template <class DeviceContext>
+using enable_if_CPU = typename std::enable_if<
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type;
+
+template <typename DeviceContext>
+class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
   void operator()(const DeviceContext& context, const framework::Tensor* X,
                   framework::Tensor* Y) {
-    auto logits = EigenMatrix<T>::From(*X);
-    auto softmax = EigenMatrix<T>::From(*Y);
-
+    auto in_dims = X->dims();
+    auto out_dims = Y->dims();
+    const float* in_data = X->data<float>();
+    float* out_data = Y->data<float>();
     const int kBatchDim = 0;
     const int kClassDim = 1;
-
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-
-    auto shifted_logits = (logits -
-                           logits.maximum(along_class)
-                               .eval()
-                               .reshape(batch_by_one)
-                               .broadcast(one_by_class));
-
-    softmax.device(*context.eigen_device()) = shifted_logits.exp();
-    softmax.device(*context.eigen_device()) = (softmax *
-                                               softmax.sum(along_class)
-                                                   .inverse()
-                                                   .eval()
-                                                   .reshape(batch_by_one)
-                                                   .broadcast(one_by_class));
+    // 2D data. Batch x C
+    const int batch_size = in_dims[kBatchDim];
+    const int num_classes = in_dims[kClassDim];
+    std::vector<float> entities(batch_size);
+    auto blas = math::GetBlas<DeviceContext, float>(context);
+    for (int n = 0; n < batch_size; ++n) {
+      entities[n] = in_data[n * num_classes];
+      for (int c = 1; c < num_classes; ++c) {
+        entities[n] = in_data[n * num_classes + c] > entities[n]
+                          ? in_data[n * num_classes + c]
+                          : entities[n];
+      }
+      for (int c = 0; c < num_classes; ++c) {
+        out_data[n * num_classes + c] =
+            in_data[n * num_classes + c] - entities[n];
+      }
+    }
+
+    blas.VEXP(num_classes * batch_size, out_data, out_data);
+    for (int n = 0; n < batch_size; ++n) {
+      entities[n] = out_data[n * num_classes];
+      for (int c = 1; c < num_classes; ++c) {
+        entities[n] += out_data[n * num_classes + c];
+      }
+      blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]);
+    }
   }
 };
 
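
Note: the extra `typename Enable` template parameter added in the second hunk exists so the CPU-only `float` specialization in the third hunk can be selected through SFINAE. The following is a minimal, self-contained sketch of that dispatch pattern, not Paddle code: the device tags are hypothetical stand-ins for platform::CPUDeviceContext / platform::CUDADeviceContext, and it assumes the primary template declares `typename Enable = void` (the primary declaration is not shown in this diff).

#include <iostream>
#include <type_traits>

// Hypothetical stand-in device tags (Paddle's real types are
// platform::CPUDeviceContext and platform::CUDADeviceContext).
struct CPUDeviceContext {};
struct CUDADeviceContext {};

template <class DeviceContext>
using enable_if_CPU = typename std::enable_if<
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type;

// Primary template: the trailing Enable parameter defaults to void,
// so existing instantiations continue to pick this version.
template <typename DeviceContext, typename T, bool is_test,
          typename Enable = void>
struct Functor {
  void operator()() { std::cout << "generic Eigen path\n"; }
};

// Partial specialization matched only when DeviceContext is the CPU
// context, T is float, and is_test is true; for any other device,
// enable_if_CPU<DeviceContext> has no ::type and SFINAE discards it.
template <typename DeviceContext>
struct Functor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
  void operator()() { std::cout << "optimized CPU float path\n"; }
};

int main() {
  Functor<CPUDeviceContext, float, true>()();   // optimized CPU float path
  Functor<CUDADeviceContext, float, true>()();  // generic Eigen path
  Functor<CPUDeviceContext, double, true>()();  // generic Eigen path
}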
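
For comparison, a plain scalar reference of the row-wise computation the specialization performs with blas.VEXP and blas.SCAL (subtract the row max for numerical stability, exponentiate, normalize by the row sum). This is a standalone sketch, not Paddle code; it could be used to sanity-check the vectorized path on small inputs.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference softmax over a row-major batch_size x num_classes buffer.
void SoftmaxRef(const float* in, float* out, int batch_size, int num_classes) {
  for (int n = 0; n < batch_size; ++n) {
    const float* row_in = in + n * num_classes;
    float* row_out = out + n * num_classes;
    // Shift by the row max, exponentiate, accumulate the row sum.
    float row_max = *std::max_element(row_in, row_in + num_classes);
    float sum = 0.f;
    for (int c = 0; c < num_classes; ++c) {
      row_out[c] = std::exp(row_in[c] - row_max);
      sum += row_out[c];
    }
    // Normalize so each row sums to 1.
    for (int c = 0; c < num_classes; ++c) row_out[c] /= sum;
  }
}

int main() {
  const int batch = 2, classes = 3;
  std::vector<float> x = {1.f, 2.f, 3.f, 0.f, 0.f, 0.f};
  std::vector<float> y(batch * classes);
  SoftmaxRef(x.data(), y.data(), batch, classes);
  for (float v : y) std::printf("%.4f ", v);
  std::printf("\n");
}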