Skip to content

Commit 159be8c

Browse files
committed
optimize fusion gru kernel at size 8
1 parent 83dc689 commit 159be8c

File tree

2 files changed

+123
-55
lines changed

2 files changed

+123
-55
lines changed

paddle/fluid/operators/math/jit_kernel_rnn.cc

Lines changed: 117 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
136136
return nullptr;
137137
}
138138

139+
// Factory that maps an activation name to a heap-allocated AVXActImpl for the
// given instruction set `isa`.
//
// Recognized names: "sigmoid", "relu", "tanh", and "identity" (an empty string
// is treated the same as "identity").
// Any other name is reported through PADDLE_THROW. The trailing
// `return nullptr` is only reached if PADDLE_THROW returns control
// (presumably it does not; it is likely there to satisfy the compiler).
template <jit::cpu_isa_t isa>
140+
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
141+
if (type == "sigmoid") {
142+
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
143+
} else if (type == "relu") {
144+
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
145+
} else if (type == "tanh") {
146+
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
147+
} else if (type == "identity" || type == "") {
148+
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
149+
}
150+
// No known activation matched the requested name.
PADDLE_THROW("Not support type: %s", type);
151+
return nullptr;
152+
}
153+
139154
/* LSTM JitKernel */
140155
template <typename T, jit::cpu_isa_t isa, jit_block>
141156
class LSTMKernelImpl : public LSTMKernel<T> {
@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
192207
#endif
193208
};
194209

195-
#define INTRI8_FLOAT(isa) \
196-
template <> \
197-
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
198-
const std::string& act_gate, const std::string& act_cand, \
199-
const std::string& act_cell, int d) \
200-
: LSTMKernel<float>() { \
201-
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
202-
if (type == "sigmoid") { \
203-
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
204-
} else if (type == "relu") { \
205-
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
206-
} else if (type == "tanh") { \
207-
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
208-
} else if (type == "identity" || type == "") { \
209-
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
210-
} \
211-
PADDLE_THROW("Not support type: %s", type); \
212-
}; \
213-
avx_act_gate_ = GetAVXAct(act_gate); \
214-
avx_act_cand_ = GetAVXAct(act_cand); \
215-
avx_act_cell_ = GetAVXAct(act_cell); \
216-
} \
217-
template <> \
218-
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
219-
float* gates, const float* ct_1, float* ct, float* ht, \
220-
const float* wp_data, float* checked) const { \
221-
/* gates: W_ch, W_ih, W_fh, W_oh */ \
222-
__m256 c, i, f, o; \
223-
c = _mm256_loadu_ps(gates); \
224-
i = _mm256_loadu_ps(gates + 8); \
225-
f = _mm256_loadu_ps(gates + 16); \
226-
o = _mm256_loadu_ps(gates + 24); \
227-
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
228-
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
229-
i = _mm256_loadu_ps(ct_1); \
230-
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
231-
f = _mm256_add_ps(c, f); \
232-
_mm256_storeu_ps(ct, f); \
233-
/* H_t = act_cell(C_t) * ogated */ \
234-
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
235-
_mm256_storeu_ps(ht, o); \
236-
} \
237-
template <> \
238-
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
239-
float* gates, float* ct, float* ht, const float* wp_data) const { \
240-
__m256 c, i, o; \
241-
c = _mm256_loadu_ps(gates); \
242-
i = _mm256_loadu_ps(gates + 8); \
243-
o = _mm256_loadu_ps(gates + 24); \
244-
/* C_t = igated * cgated*/ \
245-
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
246-
_mm256_storeu_ps(ct, c); \
247-
/* H_t = act_cell(C_t) * ogated */ \
248-
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
249-
_mm256_storeu_ps(ht, o); \
210+
// INTRI8_FLOAT(isa): emits explicit specializations of
// LSTMKernelImpl<float, isa, kEQ8> written with 256-bit AVX intrinsics:
//   * the constructor, which builds the gate / candidate / cell activations
//     through GetAVXAct<isa>() (the `d` argument is not used here);
//   * ComputeCtHt, which reads four 8-float groups from `gates` (offsets 0, 8,
//     16, 24: candidate, input, forget, output) plus `ct_1`, and stores
//     C_t into `ct` and H_t into `ht`;
//   * ComputeC1H1, the same computation without a previous cell state (the
//     forget group at offset 16 is not read).
// `wp_data` and `checked` are accepted but not read by these specializations.
// All loads and stores use the unaligned (loadu/storeu) intrinsics.
#define INTRI8_FLOAT(isa) \
211+
template <> \
212+
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
213+
const std::string& act_gate, const std::string& act_cand, \
214+
const std::string& act_cell, int d) \
215+
: LSTMKernel<float>() { \
216+
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
217+
avx_act_cand_ = GetAVXAct<isa>(act_cand); \
218+
avx_act_cell_ = GetAVXAct<isa>(act_cell); \
219+
} \
220+
template <> \
221+
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
222+
float* gates, const float* ct_1, float* ct, float* ht, \
223+
const float* wp_data, float* checked) const { \
224+
/* gates: W_ch, W_ih, W_fh, W_oh */ \
225+
__m256 c, i, f, o; \
226+
c = _mm256_loadu_ps(gates); \
227+
i = _mm256_loadu_ps(gates + 8); \
228+
f = _mm256_loadu_ps(gates + 16); \
229+
o = _mm256_loadu_ps(gates + 24); \
230+
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
231+
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
232+
i = _mm256_loadu_ps(ct_1); \
233+
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
234+
f = _mm256_add_ps(c, f); \
235+
_mm256_storeu_ps(ct, f); \
236+
/* H_t = act_cell(C_t) * ogated */ \
237+
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
238+
_mm256_storeu_ps(ht, o); \
239+
} \
240+
template <> \
241+
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
242+
float* gates, float* ct, float* ht, const float* wp_data) const { \
243+
__m256 c, i, o; \
244+
c = _mm256_loadu_ps(gates); \
245+
i = _mm256_loadu_ps(gates + 8); \
246+
o = _mm256_loadu_ps(gates + 24); \
247+
/* C_t = igated * cgated*/ \
248+
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
249+
_mm256_storeu_ps(ct, c); \
250+
/* H_t = act_cell(C_t) * ogated */ \
251+
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
252+
_mm256_storeu_ps(ht, o); \
250253
}
251254

252255
// TODO(TJ): optimize keq16
@@ -375,6 +378,7 @@ class GRUKernelImpl : public GRUKernel<T> {
375378
act_state_d_->Compute(gates + d2_, gates + d2_);
376379
vmul_d_->Compute(gates, gates + d2_, ht);
377380
}
381+
378382
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
379383
// W: {W_update, W_reset; W_state}
380384
act_gate_d2_->Compute(gates, gates);
@@ -394,8 +398,65 @@ class GRUKernelImpl : public GRUKernel<T> {
394398
int d_, d2_;
395399
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
396400
std::shared_ptr<const VMulKernel<T>> vmul_d_;
401+
#ifdef __AVX__
402+
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
403+
#endif
397404
};
398405

406+
// INTRI8_FLOAT(isa): emits explicit specializations of
// GRUKernelImpl<float, isa, kEQ8> written with 256-bit AVX intrinsics.
// Each gate group is 8 floats: update at `gates`, reset at `gates + 8`,
// state at `gates + 16`.
//   * constructor: builds the gate / state activations via GetAVXAct<isa>()
//     (the `d` argument is not used here);
//   * ComputeH1:      ht = act_gate(update) * act_state(state)
//   * ComputeHtPart1: ht = act_gate(reset) * ht_1
//   * ComputeHtPart2: ht = act_gate(update) * act_state(state)
//                          + (1 - act_gate(update)) * ht_1
// All loads and stores use the unaligned (loadu/storeu) intrinsics.
#define INTRI8_FLOAT(isa) \
407+
template <> \
408+
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
409+
const std::string& act_gate, const std::string& act_state, int d) \
410+
: GRUKernel<float>() { \
411+
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
412+
avx_act_state_ = GetAVXAct<isa>(act_state); \
413+
} \
414+
template <> \
415+
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
416+
const { \
417+
__m256 u, s; \
418+
/* W: {W_update, W_reset; W_state} */ \
419+
u = _mm256_loadu_ps(gates); \
420+
s = _mm256_loadu_ps(gates + 16); \
421+
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
422+
_mm256_storeu_ps(ht, s); \
423+
} \
424+
template <> \
425+
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
426+
float* gates, const float* ht_1, float* ht) const { \
427+
/* Not exactly equal to the primary-template implementation: the */ \
/* activated gate values are not written back into gates here. */ \
428+
__m256 r, ht0; \
429+
r = _mm256_loadu_ps(gates + 8); \
430+
ht0 = _mm256_loadu_ps(ht_1); \
431+
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
432+
_mm256_storeu_ps(ht, r); \
433+
} \
434+
template <> \
435+
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
436+
float* gates, const float* ht_1, float* ht) const { \
437+
/* Not exactly equal to the primary-template implementation. */ \
438+
__m256 u, s, ht0; \
439+
u = _mm256_loadu_ps(gates); \
440+
s = _mm256_loadu_ps(gates + 16); \
441+
ht0 = _mm256_loadu_ps(ht_1); \
442+
u = avx_act_gate_->Compute(u); \
443+
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
444+
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
445+
u = _mm256_mul_ps(u, ht0); \
446+
u = _mm256_add_ps(s, u); \
447+
_mm256_storeu_ps(ht, u); \
448+
}
449+
450+
// Expand the kEQ8 float GRU specializations once per instruction set, each
// guarded by the matching compiler feature macro. The macro body uses 256-bit
// (__m256) intrinsics for every ISA, including avx512f.
#ifdef __AVX__
451+
INTRI8_FLOAT(jit::avx);
452+
#endif
453+
#ifdef __AVX2__
454+
INTRI8_FLOAT(jit::avx2);
455+
#endif
456+
#ifdef __AVX512F__
457+
INTRI8_FLOAT(jit::avx512f);
458+
#endif
459+
399460
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
400461
template <> \
401462
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
@@ -412,6 +473,7 @@ class GRUKernelImpl : public GRUKernel<T> {
412473
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
413474
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
414475

476+
#undef INTRI8_FLOAT
415477
#undef JITKERNEL_NEW_GRU_IMPL
416478
#undef JITKERNEL_KEY_GRU
417479
#undef JITKERNEL_DECLARE_GRU

python/paddle/fluid/tests/unittests/test_fusion_gru_op.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ def set_confs(self):
125125
self.D = 8
126126

127127

128+
# Variant of TestFusionGRUOp whose set_confs() uses M = 17 and D = 15.
class TestFusionGRUOpMD3(TestFusionGRUOp):
129+
def set_confs(self):
130+
self.M = 17
131+
self.D = 15
132+
133+
128134
class TestFusionGRUOpBS1(TestFusionGRUOp):
129135
def set_confs(self):
130136
self.lod = [[3]]

0 commit comments

Comments
 (0)