Skip to content

Commit 7c87308

Browse files
authored
Merge pull request #13396 from tensor-tang/refine/op/lstm
Refine/op/lstm
2 parents edb9e56 + e09cf03 commit 7c87308

File tree

6 files changed

+146
-4
lines changed

6 files changed

+146
-4
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ op_library(flatten_op DEPS reshape_op)
296296
op_library(sequence_pad_op DEPS sequence_padding)
297297
op_library(unstack_op DEPS stack_op)
298298
op_library(fake_quantize_op DEPS memory)
299+
op_library(fusion_lstm_op DEPS cpu_lstm_compute)
299300

300301
if (WITH_GPU)
301302
op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/fusion_lstm_op.h"
1616
#include <string>
1717
#include "paddle/fluid/operators/math/blas.h"
18+
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
1819
#include "paddle/fluid/operators/math/cpu_vec.h"
1920
#include "paddle/fluid/operators/math/fc_compute.h"
2021
#include "paddle/fluid/operators/math/sequence2batch.h"
@@ -269,7 +270,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
269270
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
270271
wh_data, D4, static_cast<T>(1), out, D4)
271272

272-
// gates: W_ch, W_ih, W_fh, W_oh
273273
#define GET_Ct(ct_1, gates, ct) \
274274
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
275275
act_cand(D, gates, gates); \
@@ -395,11 +395,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
395395
}
396396
}
397397
} else {
398+
// TODO(TJ): ugly workaround, clean me
399+
std::function<void(T*, const T*, T*, T*)> compute_ctht;
400+
if (platform::jit::MayIUse(platform::jit::avx) &&
401+
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
402+
act_cell_str == "tanh" && D == 8) {
403+
compute_ctht = math::lstm_compute_ctht<T>;
404+
} else {
405+
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
406+
COMPUTE_CtHt(gates, ct_1, ct, ht);
407+
};
408+
}
398409
for (int i = 0; i < N; ++i) {
399410
PROCESS_H0C0
400411
for (int step = tstart; step < seq_len; ++step) {
401412
GEMM_WH_ADDON(1, prev_h_data, xx_data);
402-
COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
413+
compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data);
403414
MOVE_ONE_STEP;
404415
}
405416
}
@@ -532,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
532543
MOVE_ONE_STEP;
533544
}
534545
} else {
546+
// TODO(TJ): ugly workaround, clean me
547+
std::function<void(T*, const T*, T*, T*)> compute_ctht;
548+
if (platform::jit::MayIUse(platform::jit::avx) &&
549+
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
550+
act_cell_str == "tanh" && D == 8) {
551+
compute_ctht = math::lstm_compute_ctht<T>;
552+
} else {
553+
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
554+
COMPUTE_CtHt(gates, ct_1, ct, ht);
555+
};
556+
}
535557
for (int step = tstart; step < max_seq_len; ++step) {
536558
const int cur_bs = batch_starts[step + 1] - batch_starts[step];
537559
GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
538560
DEFINE_CUR;
539561
for (int i = 0; i < cur_bs; ++i) {
540-
COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
562+
compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data,
541563
cur_h_out_data);
542564
MOVE_ONE_BATCH;
543565
}

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ math_library(im2col)
4545
if (NOT WIN32) # windows do not support avx functions yet.
4646
math_library(gru_compute DEPS activation_functions math_function)
4747
math_library(lstm_compute DEPS activation_functions)
48+
# TODO(TJ): ugly workaround, clean me
49+
cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info)
4850
endif (NOT WIN32)
4951

5052
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif

namespace paddle {
namespace operators {
namespace math {

// TODO(TJ): ugly workaround, clean me
//
// Fused element-wise part of one LSTM time step for a fixed hidden size
// of 8. `gates` holds four consecutive blocks of 8 values laid out as
// [W_ch, W_ih, W_fh, W_oh] (candidate, input, forget, output gate) and is
// activated in place. `ct_1` is the previous cell state (8 values); `ct`
// and `ht` receive the new cell state and hidden state (8 values each).
// Activations are fixed: tanh for the candidate/cell, sigmoid for the
// gates. Callers are expected to check
// platform::jit::MayIUse(platform::jit::avx) before dispatching here
// (see fusion_lstm_op.cc).
template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
  constexpr int kD = 8;  // hidden size this kernel is specialized for
  // Activate the three gate blocks (sigmoid) and the candidate (tanh)
  // in place.
  vec_sigmoid<T, platform::jit::avx>(3 * kD, gates + kD, gates + kD);
  vec_tanh<T, platform::jit::avx>(kD, gates, gates);
  const T *i = gates + kD, *f = gates + 2 * kD, *o = gates + 3 * kD;
  const T min = SIGMOID_THRESHOLD_MIN;
  const T max = SIGMOID_THRESHOLD_MAX;
  for (int d = 0; d < kD; ++d) {
    // C_t = C_t-1 * fgated + cand_gated * igated
    ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
    // H_t = act_cell(C_t) * ogated, where act_cell = tanh is computed as
    // 2 * sigmoid(2x) - 1 with the argument clipped to [min, max].
    T tmp = ct[d] * 2;
    tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
    vec_exp<T>(1, &tmp, &tmp);
    tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
    ht[d] = tmp * o[d];
  }
}

#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
// Declared here to avoid pulling in the full activation headers; the
// definitions live in the activation_functions library this target
// depends on.
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
}  // namespace avx
}  // namespace forward
}  // namespace detail

// AVX specialization for float: the whole 8-wide step fits in one
// __m256 register per gate block, so process all lanes at once.
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                              float* ht) {
  namespace act = detail::forward::avx;
  // gates: W_ch, W_ih, W_fh, W_oh
  __m256 c, i, f, o;
  c = _mm256_loadu_ps(gates);
  i = _mm256_loadu_ps(gates + 8);
  f = _mm256_loadu_ps(gates + 16);
  o = _mm256_loadu_ps(gates + 24);

  /* C_t = C_t-1 * fgated + cand_gated * igated */
  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
  i = _mm256_loadu_ps(ct_1);  // reuse register i for C_t-1
  f = _mm256_mul_ps(i, act::Sigmoid(f));
  f = _mm256_add_ps(c, f);
  _mm256_storeu_ps(ct, f);

  /* H_t = act_cell(C_t) * ogated */
  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
  _mm256_storeu_ps(ht, o);
}
#endif

template void lstm_compute_ctht<float>(float* gates, const float* ct_1,
                                       float* ct, float* ht);
template void lstm_compute_ctht<double>(double* gates, const double* ct_1,
                                        double* ct, double* ht);

}  // namespace math
}  // namespace operators
}  // namespace paddle
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

namespace paddle {
namespace operators {
namespace math {

// TODO(TJ): ugly workaround, clean me
//
// Fused element-wise part of one LSTM time step for a fixed hidden size
// of 8 (see cpu_lstm_compute.cc). `gates` holds 4 x 8 values laid out as
// [W_ch, W_ih, W_fh, W_oh] and is activated in place; `ct_1` is the
// previous cell state; `ct` and `ht` receive the new cell state and
// hidden state. Activations are fixed: tanh for the candidate/cell,
// sigmoid for the gates.
template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht);

}  // namespace math
}  // namespace operators
}  // namespace paddle

paddle/fluid/operators/math/cpu_vec.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <functional>
1818
#include <string>
1919
#include "paddle/fluid/platform/cpu_info.h"
20+
#include "paddle/fluid/platform/enforce.h"
2021
#ifdef __AVX__
2122
#include <immintrin.h>
2223
#endif
@@ -476,7 +477,7 @@ class VecActivations {
476477
} else if (type == "identity" || type == "") {
477478
return vec_identity<T, isa>;
478479
}
479-
LOG(FATAL) << "Not support type: " << type;
480+
PADDLE_THROW("Not support type: %s", type);
480481
}
481482
};
482483

0 commit comments

Comments
 (0)