Skip to content

Commit 7383eef

Browse files
committed
add softmax mix and mkl code
test=develop
1 parent 5094568 commit 7383eef

File tree

6 files changed

+74
-0
lines changed

6 files changed

+74
-0
lines changed

paddle/fluid/operators/jit/more/mix/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ USE_JITKERNEL_MORE(kLSTMC1H1, mix)
1212
USE_JITKERNEL_MORE(kGRUH1, mix)
1313
USE_JITKERNEL_MORE(kGRUHtPart1, mix)
1414
USE_JITKERNEL_MORE(kGRUHtPart2, mix)
15+
USE_JITKERNEL_MORE(kSoftmax, mix)

paddle/fluid/operators/jit/more/mix/mix.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ void VTanh(const T* x, T* y, int n) {
4848
compute_addbias(&b, y, y, n);
4949
}
5050

51+
void Softmax(const T* x, T* y, int n, int bs) {
52+
auto compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
53+
auto compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
54+
auto compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
55+
auto compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
56+
auto compute_vexp =
57+
Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
58+
for (int i = 0; i < bs; ++i) {
59+
T scalar;
60+
compute_hmax(x, &scalar, n);
61+
scalar = static_cast<T>(0) - scalar;
62+
compute_vaddbias(&scalar, x, y, n); // x - max
63+
compute_vexp(y, y, n);
64+
compute_hsum(y, &scalar, n);
65+
scalar = static_cast<T>(1) / scalar;
66+
compute_vscal(&scalar, y, y, n);
67+
x += n;
68+
y += n;
69+
}
70+
}
71+
5172
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
5273
if (type == kVSigmoid) {
5374
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
@@ -184,6 +205,8 @@ bool VSigmoidKernel::UseMe(const int& d) const { return true; }
184205

185206
bool VTanhKernel::UseMe(const int& d) const { return true; }
186207

208+
// Reports the mix Softmax kernel as usable for every `d`; the argument is not
// inspected.
bool SoftmaxKernel::UseMe(const int& d) const {
  return true;
}
209+
187210
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
188211

189212
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
@@ -207,6 +230,7 @@ namespace mix = paddle::operators::jit::more::mix;
207230

208231
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
209232
REGISTER_MORE_KERNEL(kVTanh, VTanh);
233+
REGISTER_MORE_KERNEL(kSoftmax, Softmax);
210234
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
211235
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
212236
REGISTER_MORE_KERNEL(kGRUH1, GRUH1);

paddle/fluid/operators/jit/more/mix/mix.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using T = float;
2626

2727
void VSigmoid(const T* x, T* y, int n);
2828
void VTanh(const T* x, T* y, int n);
29+
void Softmax(const T* x, T* y, int n, int bs);
2930

3031
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
3132
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);
@@ -45,6 +46,9 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
4546
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
4647
DECLARE_MORE_KERNEL(VTanh, XYNTuples);
4748

49+
// Softmax (declared with SoftmaxTuples)
50+
DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples);
51+
4852
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
4953
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
5054

paddle/fluid/operators/jit/more/mkl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ USE_JITKERNEL_MORE(kVSquare, mkl)
1212
USE_JITKERNEL_MORE(kVSigmoid, mkl)
1313
USE_JITKERNEL_MORE(kVTanh, mkl)
1414
USE_JITKERNEL_MORE(kSeqPool, mkl)
15+
USE_JITKERNEL_MORE(kSoftmax, mkl)

paddle/fluid/operators/jit/more/mkl/mkl.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,16 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
116116
platform::dynload::cblas_daxpy(n, a, x, 1, y, 1);
117117
}
118118

119+
template <>
120+
void ASum<float>(const float* x, float* res, int n) {
121+
res[0] = platform::dynload::cblas_sasum(n, x, 1);
122+
}
123+
124+
template <>
125+
void ASum<double>(const double* x, double* res, int n) {
126+
res[0] = platform::dynload::cblas_dasum(n, x, 1);
127+
}
128+
119129
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
120130
template <>
121131
bool MatMulKernel<float>::UseMe(const int& d) const {
@@ -167,6 +177,11 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
167177
return true;
168178
}
169179

180+
template <>
181+
bool SoftmaxKernel<float>::UseMe(const int& d) const {
182+
return true;
183+
}
184+
170185
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
171186
template <> \
172187
bool func##Kernel<double>::UseMe(const int& d) const { \
@@ -181,6 +196,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VExp);
181196
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
182197
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
183198
AWALYS_USE_ME_WITH_DOUBLE(VSquare);
199+
AWALYS_USE_ME_WITH_DOUBLE(Softmax);
184200

185201
#undef AWALYS_USE_ME_WITH_DOUBLE
186202
} // namespace mkl
@@ -204,5 +220,6 @@ REGISTER_MKL_KERNEL(kVSquare, VSquare);
204220
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
205221
REGISTER_MKL_KERNEL(kVTanh, VTanh);
206222
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
223+
REGISTER_MKL_KERNEL(kSoftmax, Softmax);
207224

208225
#undef REGISTER_MKL_KERNEL

paddle/fluid/operators/jit/more/mkl/mkl.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <cmath>
1818
#include <type_traits>
19+
#include <vector>
1920
#include "paddle/fluid/operators/jit/kernel_base.h"
2021

2122
namespace paddle {
@@ -90,6 +91,30 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
9091
}
9192
}
9293

94+
// Reduction helper used by Softmax below: writes one value computed from the
// n elements of x to res[0]. Only declared here; mkl.cc specializes it for
// float and double via cblas_sasum / cblas_dasum.
template <typename T>
95+
void ASum(const T* x, T* res, int n);
96+
97+
// Row-wise softmax over `bs` rows of `n` contiguous elements each.
//
// x: input, read as x[i * n + c] for row i and column c.
// y: output with the same layout. It first receives x minus the row maximum,
//    is then exponentiated by a single VExp call over all n * bs elements,
//    and finally each row is scaled by the reciprocal of its ASum.
// Returns immediately when n or bs is not positive, so an empty row is never
// dereferenced.
template <typename T>
void Softmax(const T* x, T* y, int n, int bs) {
  if (n <= 0 || bs <= 0) {
    return;
  }
  for (int i = 0; i < bs; ++i) {
    const T* row_in = x + i * n;
    T* row_out = y + i * n;
    // Only the current row's maximum is ever needed, so a scalar is used
    // instead of heap-allocating storage for all bs maxima on every call.
    T row_max = row_in[0];
    for (int c = 1; c < n; ++c) {
      if (row_in[c] > row_max) {
        row_max = row_in[c];
      }
    }
    for (int c = 0; c < n; ++c) {
      row_out[c] = row_in[c] - row_max;
    }
  }
  VExp(y, y, n * bs);
  for (int i = 0; i < bs; ++i) {
    T* row_out = y + i * n;
    T sum;
    // ASum is specialized with cblas_?asum (sum of absolute values); the
    // entries here are VExp outputs, so this presumably equals the plain sum.
    ASum(row_out, &sum, n);
    sum = static_cast<T>(1) / sum;
    VScal(&sum, row_out, row_out, n);
  }
}
117+
93118
#define DECLARE_MKL_KERNEL(name, tuples) \
94119
template <typename T> \
95120
class name##Kernel : public KernelMore<tuples<T>> { \
@@ -117,6 +142,8 @@ DECLARE_MKL_KERNEL(VSquare, XYNTuples);
117142

118143
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
119144

145+
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
146+
120147
#undef DECLARE_MKL_KERNEL
121148

122149
} // namespace mkl

0 commit comments

Comments
 (0)