@@ -16,6 +16,7 @@ limitations under the License. */
16
16
#include "paddle/fluid/framework/eigen.h"
17
17
#include "paddle/fluid/framework/tensor.h"
18
18
19
+ #include "paddle/fluid/operators/math/blas.h"
19
20
namespace paddle {
20
21
namespace operators {
21
22
namespace math {
@@ -65,36 +66,42 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
65
66
.broadcast (one_by_class));
66
67
}
67
68
68
- template <typename DeviceContext, typename T >
69
- class SoftmaxFunctor <DeviceContext, T , true > {
69
+ template <typename DeviceContext>
70
+ class SoftmaxFunctor <DeviceContext, float , true > {
70
71
void operator ()(const DeviceContext& context, const framework::Tensor* X,
71
72
framework::Tensor* Y) {
72
- auto logits = EigenMatrix<T>::From (*X);
73
- auto softmax = EigenMatrix<T>::From (*Y);
74
-
73
+ auto in_dims = X->dims ();
74
+ auto out_dims = Y->dims ();
75
+ const float * in_data = X->data <float >();
76
+ float * out_data = Y->data <float >();
75
77
const int kBatchDim = 0 ;
76
78
const int kClassDim = 1 ;
77
-
78
- const int batch_size = logits.dimension (kBatchDim );
79
- const int num_classes = logits.dimension (kClassDim );
80
-
81
- Eigen::DSizes<int , 1 > along_class (kClassDim );
82
- Eigen::DSizes<int , 2 > batch_by_one (batch_size, 1 );
83
- Eigen::DSizes<int , 2 > one_by_class (1 , num_classes);
84
-
85
- auto shifted_logits = (logits -
86
- logits.maximum (along_class)
87
- .eval ()
88
- .reshape (batch_by_one)
89
- .broadcast (one_by_class));
90
-
91
- softmax.device (*context.eigen_device ()) = shifted_logits.exp ();
92
- softmax.device (*context.eigen_device ()) = (softmax *
93
- softmax.sum (along_class)
94
- .inverse ()
95
- .eval ()
96
- .reshape (batch_by_one)
97
- .broadcast (one_by_class));
79
+ // 2D data. Batch x C
80
+ const int batch_size = in_dims[kBatchDim ];
81
+ const int num_classes = in_dims[kClassDim ];
82
+ std::vector<float > entities (batch_size);
83
+ auto blas = math::GetBlas<DeviceContext, float >(context);
84
+ for (int n = 0 ; n < batch_size; ++n) {
85
+ entities[n] = in_data[n * num_classes];
86
+ for (int c = 1 ; c < num_classes; ++c) {
87
+ entities[n] = in_data[n * num_classes + c] > entities[n]
88
+ ? in_data[n * num_classes + c]
89
+ : entities[n];
90
+ }
91
+ for (int c = 0 ; c < num_classes; ++c) {
92
+ out_data[n * num_classes + c] =
93
+ in_data[n * num_classes + c] - entities[n];
94
+ }
95
+ }
96
+
97
+ blas.VEXP (num_classes * batch_size, out_data, out_data);
98
+ for (int n = 0 ; n < batch_size; ++n) {
99
+ entities[n] = out_data[n * num_classes];
100
+ for (int c = 1 ; c < num_classes; ++c) {
101
+ entities[n] += out_data[n * num_classes + c];
102
+ }
103
+ blas.SCAL (num_classes, 1 .0f / entities[n], &out_data[n * num_classes]);
104
+ }
98
105
}
99
106
};
100
107
0 commit comments