Skip to content

Commit 5447463

Browse files
authored
Merge pull request #16057 from heavengate/softmax_axis
Add attr 'axis' for softmax
2 parents 63ac947 + 3e35238 commit 5447463

File tree

22 files changed

+413
-78
lines changed

22 files changed

+413
-78
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size',
9595
paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '37042620f9bd3a2da6e5d3138b2f724b'))
9696
paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,)), ('document', 'a194fb80614023f543df3949fbd0d0b8'))
9797
paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '19ef6f9cdd27feac8a1ae060f19c10b4'))
98-
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'f19dd380864e61134ce3814e4be0de4b'))
98+
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', '59b1c6bf2f0fa9dc649c85fef3a3b2ea'))
9999
paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', 'bbd84e855e660cd1084bb71a2fd0cdaa'))
100100
paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', '043de7333b79ee0ac55053c14ed81625'))
101101
paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '859b887174d06f361658f69cb7c06d95'))

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ void BenchKernelSoftmax() {
386386
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
387387
const T* x_data = x.data<T>();
388388
T* y_data = y.mutable_data<T>(PlaceType());
389-
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
389+
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs, 1);
390390
}
391391
}
392392
}

paddle/fluid/operators/jit/helper.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ const char* to_string(KernelType kt) {
3434
ONE_CASE(kVAddRelu);
3535
ONE_CASE(kVSub);
3636
ONE_CASE(kVScal);
37+
ONE_CASE(kStrideScal);
3738
ONE_CASE(kVAddBias);
3839
ONE_CASE(kVRelu);
3940
ONE_CASE(kVBroadcast);
@@ -55,6 +56,7 @@ const char* to_string(KernelType kt) {
5556
ONE_CASE(kMatMul);
5657
ONE_CASE(kHMax);
5758
ONE_CASE(kHSum);
59+
ONE_CASE(kStrideASum);
5860
ONE_CASE(kSoftmax);
5961
ONE_CASE(kEmbSeqPool);
6062
ONE_CASE(kSgd);

paddle/fluid/operators/jit/kernel_base.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ typedef enum {
3838
kNCHW16CMulNC,
3939
kSeqPool,
4040
kSoftmax,
41+
kStrideASum,
42+
kStrideScal,
4143
kVAdd,
4244
kVAddBias,
4345
kVAddRelu,
@@ -74,6 +76,14 @@ struct XYZNTuple {
7476
template <typename T>
7577
struct AXYNTuple : public XYZNTuple<T> {};
7678

79+
// Kernel tuple for strided kernels whose arguments are, in order:
// a, x, y, n, stride
// i.e. two read-only T pointers (a, x), one writable T pointer (y) and two
// ints (n, stride). StrideScal is declared with this tuple further down via
// DECLARE_KERNELTUPLE(AXYNSTuple, StrideScal).
80+
template <typename T>
81+
struct AXYNSTuple {
82+
typedef T data_type;
83+
typedef int attr_type;
84+
typedef void (*func_type)(const T*, const T*, T*, int, int);
85+
};
86+
7787
// x, y, n
7888
template <typename T>
7989
struct XYNTuple {
@@ -86,6 +96,14 @@ struct XYNTuple {
8696
template <typename T>
8797
struct XRNTuple : public XYNTuple<T> {};
8898

99+
// Kernel tuple for strided reduction kernels whose arguments are, in order:
// x, returned value, n, stride
// i.e. a read-only T pointer (x), a writable T pointer that receives the
// result, and two ints (n, stride). StrideASum is declared with this tuple
// further down via DECLARE_KERNELTUPLE(XRNSTuple, StrideASum).
100+
template <typename T>
101+
struct XRNSTuple {
102+
typedef T data_type;
103+
typedef int attr_type;
104+
typedef void (*func_type)(const T*, T*, int, int);
105+
};
106+
89107
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
90108
template <typename T> \
91109
struct type##Tuple : public kernel_tuple<T> { \
@@ -101,6 +119,8 @@ DECLARE_KERNELTUPLE(XYZNTuple, VSub);
101119
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
102120
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
103121

122+
DECLARE_KERNELTUPLE(AXYNSTuple, StrideScal);
123+
104124
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
105125
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
106126
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
@@ -112,6 +132,8 @@ DECLARE_KERNELTUPLE(XYNTuple, VCopy);
112132
DECLARE_KERNELTUPLE(XRNTuple, HMax);
113133
DECLARE_KERNELTUPLE(XRNTuple, HSum);
114134

135+
DECLARE_KERNELTUPLE(XRNSTuple, StrideASum);
136+
115137
typedef struct {
116138
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
117139
const void* ct_1;
@@ -285,7 +307,7 @@ struct SoftmaxTuple {
285307
static constexpr KernelType kernel_type = kSoftmax;
286308
typedef T data_type;
287309
typedef int attr_type;
288-
typedef void (*func_type)(const T*, T*, int, int);
310+
typedef void (*func_type)(const T*, T*, int, int, int);
289311
};
290312

291313
// nChw16c = nChw16c .* NC

paddle/fluid/operators/jit/more/mix/mix.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,15 @@ void VTanh(const T* x, T* y, int n) {
5050
compute_addbias(&b, y, y, n);
5151
}
5252

53-
void Softmax(const T* x, T* y, int n, int bs) {
53+
// remain is the product of the sizes of all dimensions after the softmax axis
54+
void Softmax(const T* x, T* y, int n, int bs, int remain) {
5455
auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
5556
auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
5657
auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
58+
auto compute_strideasum =
59+
KernelFuncs<StrideASumTuple<T>, CPUPlace>::Cache().At(n);
60+
auto compute_stridescal =
61+
KernelFuncs<StrideScalTuple<T>, CPUPlace>::Cache().At(n);
5762
auto compute_vaddbias =
5863
KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
5964
auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
@@ -64,9 +69,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
6469
scalar = static_cast<T>(0) - scalar;
6570
compute_vaddbias(&scalar, x, y, n); // x - max
6671
compute_vexp(y, y, n);
67-
compute_hsum(y, &scalar, n);
68-
scalar = static_cast<T>(1) / scalar;
69-
compute_vscal(&scalar, y, y, n);
72+
if (remain == 1) {
73+
compute_hsum(y, &scalar, n);
74+
scalar = static_cast<T>(1) / scalar;
75+
compute_vscal(&scalar, y, y, n);
76+
} else {
77+
for (int j = 0; j < remain; ++j) {
78+
compute_strideasum(&y[j], &scalar, n, remain);
79+
scalar = static_cast<T>(1) / scalar;
80+
compute_stridescal(&scalar, &y[j], &y[j], n, remain);
81+
}
82+
}
7083
x += n;
7184
y += n;
7285
}

paddle/fluid/operators/jit/more/mix/mix.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using T = float;
2626

2727
void VSigmoid(const T* x, T* y, int n);
2828
void VTanh(const T* x, T* y, int n);
29-
void Softmax(const T* x, T* y, int n, int bs);
29+
void Softmax(const T* x, T* y, int n, int bs, int remain);
3030

3131
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
3232
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);

paddle/fluid/operators/jit/more/mkl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ USE_JITKERNEL_MORE(kMatMul, mkl)
77
USE_JITKERNEL_MORE(kVMul, mkl)
88
USE_JITKERNEL_MORE(kVAdd, mkl)
99
USE_JITKERNEL_MORE(kVScal, mkl)
10+
USE_JITKERNEL_MORE(kStrideScal, mkl)
1011
USE_JITKERNEL_MORE(kVExp, mkl)
1112
USE_JITKERNEL_MORE(kVSquare, mkl)
1213
USE_JITKERNEL_MORE(kVCopy, mkl)

paddle/fluid/operators/jit/more/mkl/mkl.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@ void VScal<double>(const double* a, const double* x, double* y, int n) {
7878
}
7979
}
8080

81+
// float specialization of StrideScal.
//   a      : pointer to the scale factor (only *a is read here)
//   x, y   : input and output buffers
//   n      : length argument; n / stride (integer division) elements are
//            handed to BLAS
//   stride : increment between consecutive elements
// cblas_sscal works in place on a single vector, so it is only used when the
// caller passes the same buffer for x and y; every other case is delegated
// to refer::StrideScal.
template <>
82+
void StrideScal<float>(const float* a, const float* x, float* y, int n,
83+
int stride) {
84+
// In-place call: let BLAS scale n / stride elements of y, stepping by stride.
if (x == y) {
85+
platform::dynload::cblas_sscal(n / stride, *a, y, stride);
86+
} else {
87+
// Distinct input and output buffers: use the reference implementation.
refer::StrideScal<float>(a, x, y, n, stride);
88+
}
89+
}
90+
91+
// double specialization of StrideScal; same structure as the float version.
//   a      : pointer to the scale factor (only *a is read here)
//   x, y   : input and output buffers
//   n      : length argument; n / stride (integer division) elements are
//            handed to BLAS
//   stride : increment between consecutive elements
// cblas_dscal works in place on a single vector, so it is only used when the
// caller passes the same buffer for x and y; every other case is delegated
// to refer::StrideScal.
template <>
92+
void StrideScal<double>(const double* a, const double* x, double* y, int n,
93+
int stride) {
94+
// In-place call: let BLAS scale n / stride elements of y, stepping by stride.
if (x == y) {
95+
platform::dynload::cblas_dscal(n / stride, *a, y, stride);
96+
} else {
97+
// Distinct input and output buffers: use the reference implementation.
refer::StrideScal<double>(a, x, y, n, stride);
98+
}
99+
}
100+
81101
template <>
82102
void VExp<float>(const float* x, float* y, int n) {
83103
platform::dynload::vsExp(n, x, y);
@@ -128,6 +148,16 @@ void ASum<double>(const double* x, double* res, int n) {
128148
res[0] = platform::dynload::cblas_dasum(n, x, 1);
129149
}
130150

151+
// float specialization of StrideASum.
// Stores in res[0] the value returned by cblas_sasum for n / stride (integer
// division) elements of x taken with increment stride. Per the BLAS asum
// contract this is the sum of absolute values of those elements.
template <>
152+
void StrideASum<float>(const float* x, float* res, int n, int stride) {
153+
res[0] = platform::dynload::cblas_sasum(n / stride, x, stride);
154+
}
155+
156+
// double specialization of StrideASum.
// Stores in res[0] the value returned by cblas_dasum for n / stride (integer
// division) elements of x taken with increment stride. Per the BLAS asum
// contract this is the sum of absolute values of those elements.
template <>
157+
void StrideASum<double>(const double* x, double* res, int n, int stride) {
158+
res[0] = platform::dynload::cblas_dasum(n / stride, x, stride);
159+
}
160+
131161
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
132162
template <>
133163
bool VMulKernel<float>::CanBeUsed(const int& d) const {
@@ -144,6 +174,11 @@ bool VScalKernel<float>::CanBeUsed(const int& d) const {
144174
return platform::MayIUse(platform::avx512f) && d > 512;
145175
}
146176

177+
// Availability check for the float StrideScal MKL kernel.
// Unconditionally returns true; the argument d is not consulted.
template <>
178+
bool StrideScalKernel<float>::CanBeUsed(const int& d) const {
179+
return true;
180+
}
181+
147182
template <>
148183
bool VExpKernel<float>::CanBeUsed(const int& d) const {
149184
return d > 7;
@@ -235,6 +270,7 @@ bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
235270
AWALYS_USE_ME_WITH_DOUBLE(VMul);
236271
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
237272
AWALYS_USE_ME_WITH_DOUBLE(VScal);
273+
AWALYS_USE_ME_WITH_DOUBLE(StrideScal);
238274
AWALYS_USE_ME_WITH_DOUBLE(VExp);
239275
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
240276
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
@@ -259,6 +295,7 @@ REGISTER_MKL_KERNEL(MatMul);
259295
REGISTER_MKL_KERNEL(VMul);
260296
REGISTER_MKL_KERNEL(VAdd);
261297
REGISTER_MKL_KERNEL(VScal);
298+
REGISTER_MKL_KERNEL(StrideScal);
262299
REGISTER_MKL_KERNEL(VExp);
263300
REGISTER_MKL_KERNEL(VSquare);
264301
REGISTER_MKL_KERNEL(VCopy);

paddle/fluid/operators/jit/more/mkl/mkl.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,14 @@ template <typename T>
129129
void ASum(const T* x, T* res, int n);
130130

131131
template <typename T>
132-
void Softmax(const T* x, T* y, int n, int bs) {
132+
void StrideASum(const T* x, T* res, int n, int stride);
133+
134+
// Strided scale primitive: arguments are the scale factor pointer a, input x,
// output y, length n and element increment stride. Only declared here; the
// float and double specializations are provided in mkl.cc.
template <typename T>
135+
void StrideScal(const T* a, const T* x, T* y, int n, int stride);
136+
137+
// remain is the product of the sizes of all dimensions after the softmax axis
138+
template <typename T>
139+
void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
133140
std::vector<T> entities(bs);
134141
for (int i = 0; i < bs; ++i) {
135142
entities[i] = x[i * n];
@@ -143,9 +150,17 @@ void Softmax(const T* x, T* y, int n, int bs) {
143150
VExp(y, y, n * bs);
144151
for (int i = 0; i < bs; ++i) {
145152
T sum;
146-
ASum(&y[i * n], &sum, n);
147-
sum = static_cast<T>(1) / sum;
148-
VScal(&sum, &y[i * n], &y[i * n], n);
153+
if (remain == 1) {
154+
ASum(&y[i * n], &sum, n);
155+
sum = static_cast<T>(1) / sum;
156+
VScal(&sum, &y[i * n], &y[i * n], n);
157+
} else {
158+
for (int j = 0; j < remain; ++j) {
159+
StrideASum(&y[i * n + j], &sum, n, remain);
160+
sum = static_cast<T>(1) / sum;
161+
StrideScal(&sum, &y[i * n + j], &y[i * n + j], n, remain);
162+
}
163+
}
149164
}
150165
}
151166

@@ -193,6 +208,7 @@ DECLARE_MKL_KERNEL(VAdd);
193208

194209
// AXYN
195210
DECLARE_MKL_KERNEL(VScal);
211+
DECLARE_MKL_KERNEL(StrideScal);
196212

197213
// XYN
198214
DECLARE_MKL_KERNEL(VExp);

paddle/fluid/operators/jit/refer/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ USE_JITKERNEL_REFER(kVAdd)
1212
USE_JITKERNEL_REFER(kVAddRelu)
1313
USE_JITKERNEL_REFER(kVSub)
1414
USE_JITKERNEL_REFER(kVScal)
15+
USE_JITKERNEL_REFER(kStrideScal)
1516
USE_JITKERNEL_REFER(kVAddBias)
1617
USE_JITKERNEL_REFER(kVCopy)
1718
USE_JITKERNEL_REFER(kVRelu)
@@ -32,6 +33,7 @@ USE_JITKERNEL_REFER(kMatMul)
3233
USE_JITKERNEL_REFER(kVSquare)
3334
USE_JITKERNEL_REFER(kHSum)
3435
USE_JITKERNEL_REFER(kHMax)
36+
USE_JITKERNEL_REFER(kStrideASum)
3537
USE_JITKERNEL_REFER(kSoftmax)
3638
USE_JITKERNEL_REFER(kEmbSeqPool)
3739
USE_JITKERNEL_REFER(kSgd)

0 commit comments

Comments
 (0)