Commit 1b894e4

Merge pull request #14437 from jczaja/prv-softmax-mkl

Introducing MKL to softmax for inference

2 parents: a94a735 + 9b0eae3

File tree

4 files changed: +53 -38 lines

    CMakeLists.txt
    paddle/fluid/operators/math/softmax.h
    paddle/fluid/operators/math/softmax_impl.h
    paddle/fluid/operators/softmax_op.h

CMakeLists.txt

Lines changed: 8 additions & 7 deletions

@@ -302,6 +302,14 @@ set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
 set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
 set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
 
+if (ON_INFER)
+  message(STATUS "On inference mode, will take place some specific optimization.")
+  add_definitions(-DPADDLE_ON_INFERENCE)
+else()
+  #TODO(luotao), combine this warning with `make inference_lib_dist` command.
+  message(WARNING "On inference mode, will take place some specific optimization. Turn on the ON_INFER flag when building inference_lib only.")
+endif()
+
 add_subdirectory(paddle)
 if(WITH_PYTHON)
   add_subdirectory(python)
@@ -312,10 +320,3 @@ if(WITH_DOC)
   find_python_module(recommonmark REQUIRED)
   add_subdirectory(doc)
 endif()
-
-if (ON_INFER)
-  message(STATUS "On inference mode, will take place some specific optimization.")
-else()
-  #TODO(luotao), combine this warning with `make inference_lib_dist` command.
-  message(WARNING "On inference mode, will take place some specific optimization. Turn on the ON_INFER flag when building inference_lib only.")
-endif()
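Besides moving the ON_INFER block ahead of add_subdirectory(paddle), this hunk makes it define PADDLE_ON_INFERENCE for every translation unit via add_definitions, so inference-only code paths are selected at compile time (the flag would be set when configuring, e.g. cmake -DON_INFER=ON — standard CMake cache syntax; the exact invocation depends on the build setup). A minimal sketch of that mechanism, not taken from the commit:

// Minimal sketch (not from the commit): a translation unit compiled with
// -DPADDLE_ON_INFERENCE picks the inference-only branch at compile time.
#include <cstdio>

void RunSoftmax() {
#ifdef PADDLE_ON_INFERENCE
  std::printf("inference-optimized path\n");  // e.g. the MKL softmax below
#else
  std::printf("training path\n");
#endif
}

int main() { RunSoftmax(); }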

paddle/fluid/operators/math/softmax.h

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,8 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename DeviceContext, typename T, bool is_test>
+template <typename DeviceContext, typename T, bool is_test,
+          typename Enable = void>
 class SoftmaxFunctor {
  public:
  void operator()(const DeviceContext& context, const framework::Tensor* X,
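The extra defaulted template parameter exists so that partial specializations can use SFINAE: a specialization names a type that only exists for certain DeviceContexts, and is silently discarded for the rest. A self-contained sketch of the pattern, using stand-in device types rather than the real Paddle classes:

// Sketch (not from the commit) of the SFINAE hook that the new
// `typename Enable = void` parameter makes possible.
#include <iostream>
#include <type_traits>

struct CPUDeviceContext {};
struct CUDADeviceContext {};

template <class DeviceContext>
using enable_if_CPU = typename std::enable_if<
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type;

// Primary template: the generic path.
template <typename DeviceContext, typename T, bool is_test,
          typename Enable = void>
struct SoftmaxFunctor {
  void operator()() { std::cout << "generic Eigen path\n"; }
};

// Partial specialization: matched only when DeviceContext is the CPU one
// (enable_if_CPU<...> is void there, ill-formed otherwise) and is_test is
// true — mirroring the commit's inference-only CPU kernel.
template <typename DeviceContext>
struct SoftmaxFunctor<DeviceContext, float, true,
                      enable_if_CPU<DeviceContext>> {
  void operator()() { std::cout << "CPU inference path\n"; }
};

int main() {
  SoftmaxFunctor<CPUDeviceContext, float, true>()();   // CPU inference path
  SoftmaxFunctor<CUDADeviceContext, float, true>()();  // generic Eigen path
  SoftmaxFunctor<CPUDeviceContext, float, false>()();  // generic Eigen path
}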

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 39 additions & 28 deletions

@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 
+#include "paddle/fluid/operators/math/blas.h"
 namespace paddle {
 namespace operators {
 namespace math {
@@ -32,8 +33,8 @@ struct ValueClip {
   }
 };
 
-template <typename DeviceContext, typename T, bool is_test>
-void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
+template <typename DeviceContext, typename T, bool is_test, typename Enable>
+void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
     const DeviceContext& context, const framework::Tensor* X,
     framework::Tensor* Y) {
   auto logits = EigenMatrix<T>::From(*X);
@@ -65,36 +66,46 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
             .broadcast(one_by_class));
 }
 
-template <typename DeviceContext, typename T>
-class SoftmaxFunctor<DeviceContext, T, true> {
+template <class DeviceContext>
+using enable_if_CPU = typename std::enable_if<
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type;
+
+template <typename DeviceContext>
+class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
   void operator()(const DeviceContext& context, const framework::Tensor* X,
                   framework::Tensor* Y) {
-    auto logits = EigenMatrix<T>::From(*X);
-    auto softmax = EigenMatrix<T>::From(*Y);
-
+    auto in_dims = X->dims();
+    auto out_dims = Y->dims();
+    const float* in_data = X->data<float>();
+    float* out_data = Y->data<float>();
     const int kBatchDim = 0;
     const int kClassDim = 1;
-
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-
-    auto shifted_logits = (logits -
-                           logits.maximum(along_class)
-                               .eval()
-                               .reshape(batch_by_one)
-                               .broadcast(one_by_class));
-
-    softmax.device(*context.eigen_device()) = shifted_logits.exp();
-    softmax.device(*context.eigen_device()) = (softmax *
-                                               softmax.sum(along_class)
-                                                   .inverse()
-                                                   .eval()
-                                                   .reshape(batch_by_one)
-                                                   .broadcast(one_by_class));
+    // 2D data. Batch x C
+    const int batch_size = in_dims[kBatchDim];
+    const int num_classes = in_dims[kClassDim];
+    std::vector<float> entities(batch_size);
+    auto blas = math::GetBlas<DeviceContext, float>(context);
+    for (int n = 0; n < batch_size; ++n) {
+      entities[n] = in_data[n * num_classes];
+      for (int c = 1; c < num_classes; ++c) {
+        entities[n] = in_data[n * num_classes + c] > entities[n]
+                          ? in_data[n * num_classes + c]
+                          : entities[n];
+      }
+      for (int c = 0; c < num_classes; ++c) {
+        out_data[n * num_classes + c] =
+            in_data[n * num_classes + c] - entities[n];
+      }
+    }
+
+    blas.VEXP(num_classes * batch_size, out_data, out_data);
+    for (int n = 0; n < batch_size; ++n) {
+      entities[n] = out_data[n * num_classes];
+      for (int c = 1; c < num_classes; ++c) {
+        entities[n] += out_data[n * num_classes + c];
+      }
+      blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]);
+    }
   }
 };
 
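The new CPU/float/inference specialization replaces the Eigen expression tree with three passes over the 2-D Batch x C buffer: find each row's maximum and subtract it (the usual overflow guard, since softmax(x)_c = exp(x_c - max(x)) / sum_j exp(x_j - max(x)) is unchanged by the shift), exponentiate the whole buffer with one batched MKL call (blas.VEXP), then accumulate each row's sum and scale the row by its reciprocal (blas.SCAL). A plain-C++ equivalent of the same math, without MKL and not taken from the commit:

// Scalar sketch of the numerically stable softmax the specialization
// computes; VEXP/SCAL batch the exp and 1/sum scaling through MKL there.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void SoftmaxCPU(const float* in, float* out, int batch_size, int num_classes) {
  for (int n = 0; n < batch_size; ++n) {
    const float* row_in = in + n * num_classes;
    float* row_out = out + n * num_classes;
    // Shift by the row maximum so exp() cannot overflow.
    const float row_max = *std::max_element(row_in, row_in + num_classes);
    float sum = 0.f;
    for (int c = 0; c < num_classes; ++c) {
      row_out[c] = std::exp(row_in[c] - row_max);  // VEXP does this batched
      sum += row_out[c];
    }
    for (int c = 0; c < num_classes; ++c) row_out[c] /= sum;  // SCAL by 1/sum
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 0.f, 0.f, 0.f};  // 2 rows, 3 classes
  std::vector<float> y(x.size());
  SoftmaxCPU(x.data(), y.data(), 2, 3);
  for (float v : y) std::printf("%.4f ", v);  // each row sums to 1
  std::printf("\n");
}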

paddle/fluid/operators/softmax_op.h

Lines changed: 4 additions & 2 deletions

@@ -35,8 +35,10 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
     Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
 
-#ifdef ON_INFER
-    math::SoftmaxFunctor<DeviceContext, T, true>()(
+#ifdef PADDLE_ON_INFERENCE
+    math::SoftmaxFunctor<
+        DeviceContext, T,
+        std::is_same<DeviceContext, platform::CPUDeviceContext>::value>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
 #else
     math::SoftmaxFunctor<DeviceContext, T, false>()(
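Two things change here: the guard is renamed to PADDLE_ON_INFERENCE to match what CMakeLists.txt now actually defines, and the kernel no longer hard-codes is_test = true — it passes std::is_same<DeviceContext, platform::CPUDeviceContext>::value, so only CPU inference builds select the MKL specialization while CUDA builds keep the generic is_test = false path. A small sketch of a type trait used as a compile-time bool template argument, with stand-in types rather than the real Paddle classes:

// Sketch (not from the commit): dispatching on a device type at compile
// time by feeding std::is_same into a bool template parameter.
#include <iostream>
#include <type_traits>

struct CPUDeviceContext {};
struct CUDADeviceContext {};

template <bool is_test>
void Dispatch() {
  std::cout << (is_test ? "inference functor\n" : "training functor\n");
}

template <typename DeviceContext>
void Compute() {
  // true only when DeviceContext is the CPU context, as in the kernel above.
  Dispatch<std::is_same<DeviceContext, CPUDeviceContext>::value>();
}

int main() {
  Compute<CPUDeviceContext>();   // inference functor
  Compute<CUDADeviceContext>();  // training functor
}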
