
Commit 6490bb2

Merge pull request #14337 from jczaja/prv-dam-softmax
Softmax op optimization for inference
2 parents 9f68e9a + 03299ed commit 6490bb2

File tree

7 files changed: +68 -15 lines


paddle/fluid/operators/math/softmax.cc

Lines changed: 4 additions & 2 deletions

@@ -19,8 +19,10 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template class SoftmaxFunctor<platform::CPUDeviceContext, float>;
-template class SoftmaxFunctor<platform::CPUDeviceContext, double>;
+template class SoftmaxFunctor<platform::CPUDeviceContext, float, true>;
+template class SoftmaxFunctor<platform::CPUDeviceContext, float, false>;
+template class SoftmaxFunctor<platform::CPUDeviceContext, double, true>;
+template class SoftmaxFunctor<platform::CPUDeviceContext, double, false>;
 template class SoftmaxGradFunctor<platform::CPUDeviceContext, float>;
 template class SoftmaxGradFunctor<platform::CPUDeviceContext, double>;
 
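Each "template class ...;" line above is an explicit instantiation definition: it forces the compiler to emit SoftmaxFunctor's code for that (device, type, is_test) combination in this translation unit, so other files can link against the functor without seeing its implementation. Adding the is_test parameter is why the list grows from two instantiations to four. A minimal sketch of the same pattern; Square and kFast are hypothetical names, not from the Paddle codebase:

// square.h -- declaration only; the definition lives in square.cc.
template <typename T, bool kFast>
struct Square {
  T operator()(T x) const;
};

// square.cc -- out-of-line definition plus explicit instantiation
// definitions. Only the combinations listed here are compiled and
// exported; linking any other combination would fail, just as it
// would for a SoftmaxFunctor combination missing from softmax.cc.
template <typename T, bool kFast>
T Square<T, kFast>::operator()(T x) const {
  return x * x;  // stand-in for the real per-combination work
}

template struct Square<float, true>;
template struct Square<float, false>;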

paddle/fluid/operators/math/softmax.cu

Lines changed: 8 additions & 3 deletions

@@ -98,9 +98,14 @@ template class SoftmaxGradCUDNNFunctor<float>;
 template class SoftmaxGradCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<platform::float16>;
 
-template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
-template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
-template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
+                              false>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
+                              true>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, float, false>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, double, false>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, float, true>;
+template class SoftmaxFunctor<platform::CUDADeviceContext, double, true>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext,

paddle/fluid/operators/math/softmax.h

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename DeviceContext, typename T>
+template <typename DeviceContext, typename T, bool is_test>
 class SoftmaxFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor* X,
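The new third template parameter makes the training/inference choice a compile-time property: the primary template keeps the general implementation, and softmax_impl.h adds a partial specialization for is_test = true. A self-contained sketch of that dispatch mechanism; Functor is an illustrative name, not Paddle's:

#include <iostream>

// Primary template: the general (training) path.
template <typename T, bool is_test>
struct Functor {
  void operator()() const { std::cout << "training path\n"; }
};

// Partial specialization, selected whenever is_test is true.
template <typename T>
struct Functor<T, true> {
  void operator()() const { std::cout << "inference path\n"; }
};

int main() {
  Functor<float, false>()();  // prints "training path"
  Functor<float, true>()();   // prints "inference path"
}

Because the selection happens at instantiation time, neither path pays for a runtime flag check in its inner loops.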

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 37 additions & 4 deletions

@@ -32,10 +32,10 @@ struct ValueClip {
   }
 };
 
-template <typename DeviceContext, typename T>
-void SoftmaxFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
-                                                  const framework::Tensor* X,
-                                                  framework::Tensor* Y) {
+template <typename DeviceContext, typename T, bool is_test>
+void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
+    const DeviceContext& context, const framework::Tensor* X,
+    framework::Tensor* Y) {
   auto logits = EigenMatrix<T>::From(*X);
   auto softmax = EigenMatrix<T>::From(*Y);

@@ -65,6 +65,39 @@ void SoftmaxFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
           .broadcast(one_by_class));
 }
 
+template <typename DeviceContext, typename T>
+class SoftmaxFunctor<DeviceContext, T, true> {
+  void operator()(const DeviceContext& context, const framework::Tensor* X,
+                  framework::Tensor* Y) {
+    auto logits = EigenMatrix<T>::From(*X);
+    auto softmax = EigenMatrix<T>::From(*Y);
+
+    const int kBatchDim = 0;
+    const int kClassDim = 1;
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+
+    auto shifted_logits = (logits -
+                           logits.maximum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class));
+
+    softmax.device(*context.eigen_device()) = shifted_logits.exp();
+    softmax.device(*context.eigen_device()) = (softmax *
+                                               softmax.sum(along_class)
+                                                   .inverse()
+                                                   .eval()
+                                                   .reshape(batch_by_one)
+                                                   .broadcast(one_by_class));
+  }
+};
+
 template <typename DeviceContext, typename T>
 void SoftmaxGradFunctor<DeviceContext, T>::operator()(
     const DeviceContext& context, const framework::Tensor* y,
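The is_test = true specialization differs from the general path in one way: it skips the ValueClip step that the training path applies to the shifted logits. Subtracting the per-row maximum already keeps exp from overflowing, since every exponent is then at most zero. A plain-loop sketch of the computation the specialization expresses in Eigen; this is a standalone illustration under those assumptions, not Paddle code:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Row-wise softmax with max-shift:
//   softmax(x)_j = exp(x_j - max(x)) / sum_k exp(x_k - max(x))
void softmax(const std::vector<float>& logits, int batch, int classes,
             std::vector<float>* out) {
  for (int i = 0; i < batch; ++i) {
    const float* row = &logits[i * classes];
    float* dst = &(*out)[i * classes];
    // Shift by the row maximum so every exponent is <= 0.
    float max_v = *std::max_element(row, row + classes);
    float sum = 0.f;
    for (int j = 0; j < classes; ++j) {
      dst[j] = std::exp(row[j] - max_v);
      sum += dst[j];
    }
    // Normalize; mirrors the sum().inverse() broadcast above.
    for (int j = 0; j < classes; ++j) dst[j] /= sum;
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};  // one row, three classes
  std::vector<float> y(3);
  softmax(x, 1, 3, &y);
  std::printf("%.4f %.4f %.4f\n", y[0], y[1], y[2]);  // ~0.0900 0.2447 0.6652
}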

paddle/fluid/operators/softmax_op.h

Lines changed: 8 additions & 2 deletions

@@ -35,8 +35,14 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
     Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
 
-    math::SoftmaxFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
+    const bool is_test = context.Attr<bool>("is_test");
+    if (is_test == true) {
+      math::SoftmaxFunctor<DeviceContext, T, true>()(
+          context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
+    } else {
+      math::SoftmaxFunctor<DeviceContext, T, false>()(
+          context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
+    }
   }
 };
 
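The kernel reads the is_test attribute once at run time and branches to the matching compile-time instantiation, so each specialization compiles to straight-line code with no flag checks inside the hot loop. The same runtime-to-compile-time dispatch in miniature; Run and Dispatch are hypothetical names, not the Paddle API:

#include <cstdio>

template <bool is_test>
void Run() {
  // In the real kernel each instantiation does different work;
  // here we only report which body was compiled in.
  std::printf(is_test ? "inference kernel\n" : "training kernel\n");
}

// One runtime branch selects between two statically compiled bodies.
void Dispatch(bool is_test) {
  if (is_test) {
    Run<true>();
  } else {
    Run<false>();
  }
}

int main() {
  Dispatch(true);   // inference kernel
  Dispatch(false);  // training kernel
}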

paddle/fluid/operators/softmax_with_cross_entropy_op.h

Lines changed: 2 additions & 2 deletions

@@ -42,8 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 
     auto& dev_ctx =
         context.template device_context<platform::CPUDeviceContext>();
-    math::SoftmaxFunctor<platform::CPUDeviceContext, T>()(dev_ctx, logits,
-                                                          softmax);
+    math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
+        dev_ctx, logits, softmax);
     math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
         dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
         context.Attr<int>("ignore_index"));

python/paddle/fluid/tests/unittests/test_softmax_op.py

Lines changed: 8 additions & 1 deletion

@@ -35,6 +35,7 @@ def setUp(self):
         self.op_type = "softmax"
         self.use_cudnn = False
         self.use_mkldnn = False
+        self.is_test = False
         self.dtype = np.float32
         self.init_kernel_type()
         self.shape = self.get_x_shape()

@@ -48,7 +49,8 @@ def setUp(self):
         self.outputs = {'Out': out}
         self.attrs = {
             'use_cudnn': self.use_cudnn,
-            'use_mkldnn': self.use_mkldnn
+            'use_mkldnn': self.use_mkldnn,
+            'is_test': self.is_test
         }
 
     def init_kernel_type(self):

@@ -144,6 +146,11 @@ def get_x_shape(self):
         return [2, 3, 4, 5]
 
 
+class TestSoftmaxInference(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.is_test = True
+
+
 class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
     def init_kernel_type(self):
         self.use_mkldnn = True
