Skip to content

Commit a6a1a92

Browse files
authored
Merge pull request #15586 from tensor-tang/jit/cache
refine bert
2 parents e887d71 + 2b0811c commit a6a1a92

File tree

8 files changed

+41
-59
lines changed

8 files changed

+41
-59
lines changed

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -93,6 +93,7 @@ std::vector<int> TestSizes() {
9393
template <typename KernelTuples, typename... Args>
9494
struct BenchFunc {
9595
// return this function avg time
96+
// TODO(TJ): clear cache every time
9697
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
9798
for (int i = 0; i < FLAGS_burning; ++i) {
9899
tgt(args...);
@@ -172,6 +173,9 @@ void BenchXYZNKernel() {
172173
RandomVec<T>(d, y_data);
173174
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
174175
y.data<T>(), z_data, d);
176+
// test inplace
177+
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
178+
z_data, d);
175179
}
176180
}
177181

paddle/fluid/operators/jit/gen/blas.cc

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
155155
class name##Creator : public JitCodeCreator<int> { \
156156
public: \
157157
bool UseMe(const int& attr) const override { \
158-
return platform::MayIUse(platform::avx); \
158+
return platform::MayIUse(platform::avx) && attr <= 1024; \
159159
} \
160160
size_t CodeSize(const int& d) const override { \
161161
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \

paddle/fluid/operators/jit/gen/blas.h

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -61,6 +61,7 @@ class VXXJitCode : public JitCode {
6161
base += "_Vec";
6262
}
6363
base += (with_relu_ ? "_Relu" : "");
64+
base += "_D" + std::to_string(num_);
6465
return base.c_str();
6566
}
6667
void genCode() override;

paddle/fluid/operators/jit/helper.h

Lines changed: 15 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -118,26 +118,33 @@ typename KernelTuples::func_type Get(
118118
return GetRefer<KT, KernelTuples>();
119119
}
120120

121-
template <KernelType KT, typename KernelTuples>
122-
class KernelFuncsCache {
121+
template <KernelType KT, typename KernelTuples, typename PlaceType>
122+
class KernelFuncs {
123123
public:
124-
KernelFuncsCache() = default;
125-
static KernelFuncsCache& Instance() {
126-
static thread_local KernelFuncsCache<KT, KernelTuples> g_func_cache;
124+
KernelFuncs() = default;
125+
static KernelFuncs& Cache() {
126+
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
127127
return g_func_cache;
128128
}
129129

130130
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
131131

132-
typename KernelTuples::func_type At(int key) { return funcs_.at(key); }
133-
134132
void Insert(int key, typename KernelTuples::func_type func) {
135133
funcs_.emplace(key, func);
136134
}
137135

136+
typename KernelTuples::func_type At(int key) {
137+
if (Has(key)) {
138+
return funcs_.at(key);
139+
}
140+
auto func = Get<KT, KernelTuples, PlaceType>(key);
141+
Insert(key, func);
142+
return func;
143+
}
144+
138145
private:
139146
std::unordered_map<int, typename KernelTuples::func_type> funcs_;
140-
DISABLE_COPY_AND_ASSIGN(KernelFuncsCache);
147+
DISABLE_COPY_AND_ASSIGN(KernelFuncs);
141148
};
142149

143150
const char* to_string(KernelType kt);

paddle/fluid/operators/jit/more/mix/mix.cc

Lines changed: 10 additions & 43 deletions
Original file line number · Diff line number · Diff line change
@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
4949
}
5050

5151
void Softmax(const T* x, T* y, int n, int bs) {
52-
typename XRNTuples<T>::func_type compute_hmax{nullptr};
53-
typename XRNTuples<T>::func_type compute_hsum{nullptr};
54-
typename AXYNTuples<T>::func_type compute_vscal{nullptr};
55-
typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
56-
typename XYNTuples<T>::func_type compute_vexp{nullptr};
57-
58-
if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
59-
compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
60-
KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
61-
} else {
62-
compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
63-
}
64-
65-
if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
66-
compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
67-
KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
68-
} else {
69-
compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
70-
}
71-
72-
if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
73-
compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
74-
KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
75-
compute_vscal);
76-
} else {
77-
compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
78-
}
79-
80-
if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
81-
compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
82-
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
83-
n, compute_vaddbias);
84-
} else {
85-
compute_vaddbias =
86-
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
87-
}
88-
89-
if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
90-
compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
91-
KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
92-
} else {
93-
compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
94-
}
52+
auto compute_hmax =
53+
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
54+
auto compute_hsum =
55+
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
56+
auto compute_vscal =
57+
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
58+
auto compute_vaddbias =
59+
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
60+
auto compute_vexp =
61+
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
9562

9663
for (int i = 0; i < bs; ++i) {
9764
T scalar;

paddle/fluid/operators/jit/more/mkl/mkl.cc

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -136,7 +136,7 @@ bool VMulKernel<float>::UseMe(const int& d) const {
136136

137137
template <>
138138
bool VAddKernel<float>::UseMe(const int& d) const {
139-
return platform::MayIUse(platform::avx512f) && d > 512;
139+
return platform::MayIUse(platform::avx) && d > 512;
140140
}
141141

142142
template <>

paddle/fluid/operators/math/fc_compute.h

Lines changed: 6 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -30,15 +30,17 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
3030
return;
3131
}
3232
if (relu) {
33-
auto compute =
34-
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
33+
auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
34+
platform::CPUPlace>::Cache()
35+
.At(N);
3536
for (int i = 0; i < M; i++) {
3637
T* dst = Y + i * N;
3738
compute(B, dst, dst, N);
3839
}
3940
} else {
40-
auto compute =
41-
jit::Get<jit::kVAdd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
41+
auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>,
42+
platform::CPUPlace>::Cache()
43+
.At(N);
4244
#ifdef PADDLE_WITH_MKLML
4345
#pragma omp parallel for
4446
#endif

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 3 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -82,8 +82,9 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
8282
const int kClassDim = 1;
8383
// 2D data. Batch x C
8484
auto compute_softmax =
85-
jit::Get<jit::kSoftmax, jit::SoftmaxTuples<float>, platform::CPUPlace>(
86-
in_dims[kClassDim]);
85+
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
86+
platform::CPUPlace>::Cache()
87+
.At(in_dims[kClassDim]);
8788
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
8889
}
8990
};

0 commit comments

Comments (0)