
Commit ea47685

Merge pull request #14646 from jczaja/prv-softmax-mkl-sasum
Softmax for inference MKL further changes
2 parents 2238b96 + 48e1b97 commit ea47685

5 files changed: 39 additions, 8 deletions


paddle/fluid/operators/math/blas.h

Lines changed: 8 additions & 0 deletions
@@ -168,6 +168,9 @@ class Blas {
   template <typename T>
   void SCAL(int n, const T a, T* x) const;
 
+  template <typename T>
+  T ASUM(int n, T* x, int inc) const;
+
   template <typename T>
   void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                    int K, T alpha, const T* A, const T* B, T beta, T* C,
@@ -269,6 +272,11 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template SCAL<T>(args...);
   }
 
+  template <typename... ARGS>
+  T ASUM(ARGS... args) const {
+    return Base()->template ASUM<T>(args...);
+  }
+
   template <typename... ARGS>
   void BatchedGEMM(ARGS... args) const {
     Base()->template BatchedGEMM<T>(args...);
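
Note: the new ASUM entry mirrors the existing SCAL wrapper, with BlasT forwarding to the device-specific Blas implementation. A hedged usage sketch follows; RowAsum and the GetBlas call pattern are illustrative assumptions about the surrounding framework, not part of this commit.

#include "paddle/fluid/operators/math/blas.h"  // for math::GetBlas / BlasT (assumed include path)

// Sketch only: sum one softmax row through the templated BLAS wrapper.
template <typename DeviceContext>
float RowAsum(const DeviceContext& ctx, float* row, int num_classes) {
  auto blas = paddle::operators::math::GetBlas<DeviceContext, float>(ctx);
  return blas.ASUM(num_classes, row, /*inc=*/1);  // dispatches to cblas_sasum under MKLML
}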

paddle/fluid/operators/math/blas_impl.h

Lines changed: 26 additions & 0 deletions
@@ -84,6 +84,11 @@ struct CBlas<float> {
     platform::dynload::cblas_sscal(args...);
   }
 
+  template <typename... ARGS>
+  static float ASUM(ARGS... args) {
+    return platform::dynload::cblas_sasum(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_sgemm_batch(args...);
@@ -174,6 +179,11 @@ struct CBlas<double> {
     platform::dynload::cblas_dscal(args...);
   }
 
+  template <typename... ARGS>
+  static double ASUM(ARGS... args) {
+    return platform::dynload::cblas_dasum(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_dgemm_batch(args...);
@@ -268,6 +278,7 @@ struct CBlas<platform::float16> {
   static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
   static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
+  static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
     PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -476,6 +487,21 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
 #endif
 }
 
+template <>
+template <typename T>
+T Blas<platform::CPUDeviceContext>::ASUM(int n, T *x, int inc) const {
+  auto sum = static_cast<T>(0.0);
+#ifdef PADDLE_WITH_MKLML
+  sum = CBlas<T>::ASUM(n, x, inc);
+#else
+  // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum
+  for (int c = 0; c < n; ++c) {
+    sum += x[c];
+  }
+#endif
+  return sum;
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
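
Note: BLAS sasum/dasum return the sum of magnitudes, sum over |x[i*inc]|, whereas the non-MKLML fallback above sums raw values with unit stride. That is sufficient for the softmax caller (the inputs are exp() outputs, so non-negative, and inc is 1), but it is not a general asum. A small reference sketch, illustrative only and not part of the commit:

#include <cmath>  // std::abs

// What a stride-aware BLAS-style asum computes, written as a plain loop.
template <typename T>
T ReferenceAsum(int n, const T* x, int inc) {
  T sum = static_cast<T>(0);
  for (int i = 0; i < n; ++i) {
    sum += std::abs(x[i * inc]);  // sum of magnitudes, per the BLAS definition
  }
  return sum;
}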

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 2 additions & 5 deletions
@@ -100,11 +100,8 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
 
     blas.VEXP(num_classes * batch_size, out_data, out_data);
     for (int n = 0; n < batch_size; ++n) {
-      entities[n] = out_data[n * num_classes];
-      for (int c = 1; c < num_classes; ++c) {
-        entities[n] += out_data[n * num_classes + c];
-      }
-      blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]);
+      auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1);
+      blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]);
     }
   }
 };
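
Note: this loop is the per-row softmax normalization, out[n][c] = exp(in[n][c]) / sum over k of exp(in[n][k]). VEXP has already exponentiated every element; ASUM now produces each row's denominator in a single BLAS call (the exponentials are non-negative, so the sum of magnitudes equals the plain sum), and SCAL divides the row by it. The manual accumulation loop and the entities[n] buffer are no longer needed for the sum.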

paddle/fluid/operators/softmax_op.h

Lines changed: 1 addition & 3 deletions
@@ -36,9 +36,7 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
 
 #ifdef PADDLE_ON_INFERENCE
-    math::SoftmaxFunctor<
-        DeviceContext, T,
-        std::is_same<DeviceContext, platform::CPUDeviceContext>::value>()(
+    math::SoftmaxFunctor<DeviceContext, T, true>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
 #else
     math::SoftmaxFunctor<DeviceContext, T, false>()(
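
Note: under PADDLE_ON_INFERENCE the third (boolean) template argument of SoftmaxFunctor is now simply true, rather than being derived from a std::is_same check against platform::CPUDeviceContext; when the kernel runs with CPUDeviceContext, this selects the CPU-specialized functor shown in softmax_impl.h above. The non-inference path still passes false.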

paddle/fluid/platform/dynload/mklml.h

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ extern void* mklml_dso_handle;
   __macro(cblas_dgemm_batch); \
   __macro(cblas_sdot);        \
   __macro(cblas_ddot);        \
+  __macro(cblas_sasum);       \
+  __macro(cblas_dasum);       \
   __macro(cblas_sscal);       \
   __macro(cblas_dscal);       \
   __macro(vsAdd);             \
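
Note: adding cblas_sasum and cblas_dasum to this macro list is what makes platform::dynload::cblas_sasum / cblas_dasum available to the CBlas<float> / CBlas<double> wrappers in blas_impl.h: each __macro(name) entry declares a dynamic-load wrapper that resolves the symbol from the MKLML shared library (mklml_dso_handle) at runtime.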
