Skip to content

Commit d59f733

Browse files
committed
refine softmax and use with cache
test=develop
1 parent 7383eef commit d59f733

File tree

7 files changed

+102
-34
lines changed

7 files changed

+102
-34
lines changed

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ void BenchAXYNKernel() {
187187
RandomVec<T>(d, x_data);
188188
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
189189
d);
190+
// test inplace
191+
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data,
192+
d);
190193
}
191194
}
192195

paddle/fluid/operators/jit/gen/act.cc

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ void VActJitCode::genCode() {
8181
#define DECLARE_ACT_CREATOR(name) \
8282
class name##Creator : public JitCodeCreator<int> { \
8383
public: \
84-
bool UseMe(const int& attr) const override { \
85-
return platform::MayIUse(platform::avx); \
86-
} \
84+
bool UseMe(const int& attr) const override; \
8785
size_t CodeSize(const int& d) const override; \
8886
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
8987
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
@@ -98,6 +96,30 @@ DECLARE_ACT_CREATOR(VSigmoid);
9896
DECLARE_ACT_CREATOR(VTanh);
9997

10098
// TODO(TJ): tuning use me
99+
bool VReluCreator::UseMe(const int& d) const {
100+
return platform::MayIUse(platform::avx);
101+
}
102+
103+
bool VSquareCreator::UseMe(const int& d) const {
104+
return platform::MayIUse(platform::avx);
105+
}
106+
107+
bool VIdentityCreator::UseMe(const int& d) const {
108+
return platform::MayIUse(platform::avx);
109+
}
110+
111+
bool VExpCreator::UseMe(const int& d) const {
112+
return platform::MayIUse(platform::avx) && d < 32;
113+
}
114+
115+
bool VSigmoidCreator::UseMe(const int& d) const {
116+
return platform::MayIUse(platform::avx);
117+
}
118+
119+
bool VTanhCreator::UseMe(const int& d) const {
120+
return platform::MayIUse(platform::avx);
121+
}
122+
101123
size_t VReluCreator::CodeSize(const int& d) const {
102124
return 96 /* init size */ +
103125
(d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *

paddle/fluid/operators/jit/helper.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,28 @@ typename KernelTuples::func_type Get(
118118
return GetRefer<KT, KernelTuples>();
119119
}
120120

121+
template <KernelType KT, typename KernelTuples>
122+
class KernelFuncsCache {
123+
public:
124+
KernelFuncsCache() = default;
125+
static KernelFuncsCache& Instance() {
126+
static thread_local KernelFuncsCache<KT, KernelTuples> g_func_cache;
127+
return g_func_cache;
128+
}
129+
130+
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
131+
132+
typename KernelTuples::func_type At(int key) { return funcs_.at(key); }
133+
134+
void Insert(int key, typename KernelTuples::func_type func) {
135+
funcs_.emplace(key, func);
136+
}
137+
138+
private:
139+
std::unordered_map<int, typename KernelTuples::func_type> funcs_;
140+
DISABLE_COPY_AND_ASSIGN(KernelFuncsCache);
141+
};
142+
121143
const char* to_string(KernelType kt);
122144
const char* to_string(SeqPoolType kt);
123145

paddle/fluid/operators/jit/more/mix/mix.cc

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,50 @@ void VTanh(const T* x, T* y, int n) {
4949
}
5050

5151
void Softmax(const T* x, T* y, int n, int bs) {
52-
auto compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
53-
auto compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
54-
auto compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
55-
auto compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
56-
auto compute_vexp =
57-
Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
52+
typename XRNTuples<T>::func_type compute_hmax{nullptr};
53+
typename XRNTuples<T>::func_type compute_hsum{nullptr};
54+
typename AXYNTuples<T>::func_type compute_vscal{nullptr};
55+
typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
56+
typename XYNTuples<T>::func_type compute_vexp{nullptr};
57+
58+
if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
59+
compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
60+
KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
61+
} else {
62+
compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
63+
}
64+
65+
if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
66+
compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
67+
KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
68+
} else {
69+
compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
70+
}
71+
72+
if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
73+
compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
74+
KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
75+
compute_vscal);
76+
} else {
77+
compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
78+
}
79+
80+
if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
81+
compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
82+
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
83+
n, compute_vaddbias);
84+
} else {
85+
compute_vaddbias =
86+
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
87+
}
88+
89+
if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
90+
compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
91+
KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
92+
} else {
93+
compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
94+
}
95+
5896
for (int i = 0; i < bs; ++i) {
5997
T scalar;
6098
compute_hmax(x, &scalar, n);

paddle/fluid/operators/jit/more/mkl/mkl.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
179179

180180
template <>
181181
bool SoftmaxKernel<float>::UseMe(const int& d) const {
182-
return true;
182+
// tuned on avx2
183+
return platform::MayIUse(platform::avx) && d < 60;
183184
}
184185

185186
#define AWALYS_USE_ME_WITH_DOUBLE(func) \

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ math_library(sequence2batch)
5353
math_library(sequence_padding)
5454
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
5555
math_library(sequence_scale)
56-
math_library(softmax DEPS math_function)
56+
math_library(softmax DEPS math_function jit_kernel_helper)
5757
math_library(beam_search DEPS math_function)
5858

5959
math_library(matrix_bit_code)

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License. */
1616
#include <vector>
1717
#include "paddle/fluid/framework/eigen.h"
1818
#include "paddle/fluid/framework/tensor.h"
19+
#include "paddle/fluid/operators/jit/kernels.h"
1920

20-
#include "paddle/fluid/operators/math/blas.h"
2121
namespace paddle {
2222
namespace operators {
2323
namespace math {
@@ -81,28 +81,10 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
8181
const int kBatchDim = 0;
8282
const int kClassDim = 1;
8383
// 2D data. Batch x C
84-
const int batch_size = in_dims[kBatchDim];
85-
const int num_classes = in_dims[kClassDim];
86-
std::vector<float> entities(batch_size);
87-
auto blas = math::GetBlas<DeviceContext, float>(context);
88-
for (int n = 0; n < batch_size; ++n) {
89-
entities[n] = in_data[n * num_classes];
90-
for (int c = 1; c < num_classes; ++c) {
91-
entities[n] = in_data[n * num_classes + c] > entities[n]
92-
? in_data[n * num_classes + c]
93-
: entities[n];
94-
}
95-
for (int c = 0; c < num_classes; ++c) {
96-
out_data[n * num_classes + c] =
97-
in_data[n * num_classes + c] - entities[n];
98-
}
99-
}
100-
101-
blas.VEXP(num_classes * batch_size, out_data, out_data);
102-
for (int n = 0; n < batch_size; ++n) {
103-
auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1);
104-
blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]);
105-
}
84+
auto compute_softmax =
85+
jit::Get<jit::kSoftmax, jit::SoftmaxTuples<float>, platform::CPUPlace>(
86+
in_dims[kClassDim]);
87+
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
10688
}
10789
};
10890

0 commit comments

Comments
 (0)