Skip to content

Commit d3e63e6

Browse files
authored
Merge pull request #14412 from jczaja/prv-dam-softmax
Softmax for Inference is enabled when ON_INFER is set
2 parents 77ac30e + b361579 commit d3e63e6

File tree

6 files changed

+58
-13
lines changed

6 files changed

+58
-13
lines changed

paddle/fluid/operators/math/softmax.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ namespace paddle {
1919
namespace operators {
2020
namespace math {
2121

22-
template class SoftmaxFunctor<platform::CPUDeviceContext, float>;
23-
template class SoftmaxFunctor<platform::CPUDeviceContext, double>;
22+
template class SoftmaxFunctor<platform::CPUDeviceContext, float, true>;
23+
template class SoftmaxFunctor<platform::CPUDeviceContext, float, false>;
24+
template class SoftmaxFunctor<platform::CPUDeviceContext, double, true>;
25+
template class SoftmaxFunctor<platform::CPUDeviceContext, double, false>;
2426
template class SoftmaxGradFunctor<platform::CPUDeviceContext, float>;
2527
template class SoftmaxGradFunctor<platform::CPUDeviceContext, double>;
2628

paddle/fluid/operators/math/softmax.cu

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,14 @@ template class SoftmaxGradCUDNNFunctor<float>;
9898
template class SoftmaxGradCUDNNFunctor<double>;
9999
template class SoftmaxGradCUDNNFunctor<platform::float16>;
100100

101-
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
102-
template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
103-
template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
101+
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
102+
false>;
103+
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
104+
true>;
105+
template class SoftmaxFunctor<platform::CUDADeviceContext, float, false>;
106+
template class SoftmaxFunctor<platform::CUDADeviceContext, double, false>;
107+
template class SoftmaxFunctor<platform::CUDADeviceContext, float, true>;
108+
template class SoftmaxFunctor<platform::CUDADeviceContext, double, true>;
104109
template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
105110
template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
106111
template class SoftmaxGradFunctor<platform::CUDADeviceContext,

paddle/fluid/operators/math/softmax.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace paddle {
1919
namespace operators {
2020
namespace math {
2121

22-
template <typename DeviceContext, typename T>
22+
template <typename DeviceContext, typename T, bool is_test>
2323
class SoftmaxFunctor {
2424
public:
2525
void operator()(const DeviceContext& context, const framework::Tensor* X,

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ struct ValueClip {
3232
}
3333
};
3434

35-
template <typename DeviceContext, typename T>
36-
void SoftmaxFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
37-
const framework::Tensor* X,
38-
framework::Tensor* Y) {
35+
template <typename DeviceContext, typename T, bool is_test>
36+
void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
37+
const DeviceContext& context, const framework::Tensor* X,
38+
framework::Tensor* Y) {
3939
auto logits = EigenMatrix<T>::From(*X);
4040
auto softmax = EigenMatrix<T>::From(*Y);
4141

@@ -65,6 +65,39 @@ void SoftmaxFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
6565
.broadcast(one_by_class));
6666
}
6767

68+
template <typename DeviceContext, typename T>
69+
class SoftmaxFunctor<DeviceContext, T, true> {
70+
void operator()(const DeviceContext& context, const framework::Tensor* X,
71+
framework::Tensor* Y) {
72+
auto logits = EigenMatrix<T>::From(*X);
73+
auto softmax = EigenMatrix<T>::From(*Y);
74+
75+
const int kBatchDim = 0;
76+
const int kClassDim = 1;
77+
78+
const int batch_size = logits.dimension(kBatchDim);
79+
const int num_classes = logits.dimension(kClassDim);
80+
81+
Eigen::DSizes<int, 1> along_class(kClassDim);
82+
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
83+
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
84+
85+
auto shifted_logits = (logits -
86+
logits.maximum(along_class)
87+
.eval()
88+
.reshape(batch_by_one)
89+
.broadcast(one_by_class));
90+
91+
softmax.device(*context.eigen_device()) = shifted_logits.exp();
92+
softmax.device(*context.eigen_device()) = (softmax *
93+
softmax.sum(along_class)
94+
.inverse()
95+
.eval()
96+
.reshape(batch_by_one)
97+
.broadcast(one_by_class));
98+
}
99+
};
100+
68101
template <typename DeviceContext, typename T>
69102
void SoftmaxGradFunctor<DeviceContext, T>::operator()(
70103
const DeviceContext& context, const framework::Tensor* y,

paddle/fluid/operators/softmax_op.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
3535
Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
3636
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
3737

38-
math::SoftmaxFunctor<DeviceContext, T>()(
38+
#ifdef ON_INFER
39+
math::SoftmaxFunctor<DeviceContext, T, true>()(
3940
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
41+
#else
42+
math::SoftmaxFunctor<DeviceContext, T, false>()(
43+
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
44+
#endif
4045
}
4146
};
4247

paddle/fluid/operators/softmax_with_cross_entropy_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
4242

4343
auto& dev_ctx =
4444
context.template device_context<platform::CPUDeviceContext>();
45-
math::SoftmaxFunctor<platform::CPUDeviceContext, T>()(dev_ctx, logits,
46-
softmax);
45+
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
46+
dev_ctx, logits, softmax);
4747
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
4848
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
4949
context.Attr<int>("ignore_index"));

0 commit comments

Comments
 (0)