Skip to content

Commit 6e1ee7f

Browse files
committed
cache softmax kernel func
test=develop
1 parent c744922 commit 6e1ee7f

File tree

3 files changed

+28
-53
lines changed

3 files changed

+28
-53
lines changed

paddle/fluid/operators/jit/helper.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,26 +118,33 @@ typename KernelTuples::func_type Get(
118118
return GetRefer<KT, KernelTuples>();
119119
}
120120

121-
template <KernelType KT, typename KernelTuples>
122-
class KernelFuncsCache {
121+
template <KernelType KT, typename KernelTuples, typename PlaceType>
122+
class KernelFuncs {
123123
public:
124-
KernelFuncsCache() = default;
125-
static KernelFuncsCache& Instance() {
126-
static thread_local KernelFuncsCache<KT, KernelTuples> g_func_cache;
124+
KernelFuncs() = default;
125+
static KernelFuncs& Cache() {
126+
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
127127
return g_func_cache;
128128
}
129129

130130
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
131131

132-
typename KernelTuples::func_type At(int key) { return funcs_.at(key); }
133-
134132
void Insert(int key, typename KernelTuples::func_type func) {
135133
funcs_.emplace(key, func);
136134
}
137135

136+
typename KernelTuples::func_type At(int key) {
137+
if (Has(key)) {
138+
return funcs_.at(key);
139+
}
140+
auto func = Get<KT, KernelTuples, PlaceType>(key);
141+
Insert(key, func);
142+
return func;
143+
}
144+
138145
private:
139146
std::unordered_map<int, typename KernelTuples::func_type> funcs_;
140-
DISABLE_COPY_AND_ASSIGN(KernelFuncsCache);
147+
DISABLE_COPY_AND_ASSIGN(KernelFuncs);
141148
};
142149

143150
const char* to_string(KernelType kt);

paddle/fluid/operators/jit/more/mix/mix.cc

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
4949
}
5050

5151
void Softmax(const T* x, T* y, int n, int bs) {
52-
typename XRNTuples<T>::func_type compute_hmax{nullptr};
53-
typename XRNTuples<T>::func_type compute_hsum{nullptr};
54-
typename AXYNTuples<T>::func_type compute_vscal{nullptr};
55-
typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
56-
typename XYNTuples<T>::func_type compute_vexp{nullptr};
57-
58-
if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
59-
compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
60-
KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
61-
} else {
62-
compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
63-
}
64-
65-
if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
66-
compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
67-
KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
68-
} else {
69-
compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
70-
}
71-
72-
if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
73-
compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
74-
KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
75-
compute_vscal);
76-
} else {
77-
compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
78-
}
79-
80-
if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
81-
compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
82-
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
83-
n, compute_vaddbias);
84-
} else {
85-
compute_vaddbias =
86-
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
87-
}
88-
89-
if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
90-
compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
91-
KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
92-
} else {
93-
compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
94-
}
52+
auto compute_hmax =
53+
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
54+
auto compute_hsum =
55+
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
56+
auto compute_vscal =
57+
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
58+
auto compute_vaddbias =
59+
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
60+
auto compute_vexp =
61+
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
9562

9663
for (int i = 0; i < bs; ++i) {
9764
T scalar;

paddle/fluid/operators/math/softmax_impl.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
8282
const int kClassDim = 1;
8383
// 2D data. Batch x C
8484
auto compute_softmax =
85-
jit::Get<jit::kSoftmax, jit::SoftmaxTuples<float>, platform::CPUPlace>(
86-
in_dims[kClassDim]);
85+
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
86+
platform::CPUPlace>::Cache()
87+
.At(in_dims[kClassDim]);
8788
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
8889
}
8990
};

0 commit comments

Comments
 (0)