Skip to content

Commit 612ba41

Browse files
committed
add simple lstm compute
1 parent 83035e9 commit 612ba41

File tree

5 files changed

+123
-2
lines changed

5 files changed

+123
-2
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ op_library(flatten_op DEPS reshape_op)
296296
op_library(sequence_pad_op DEPS sequence_padding)
297297
op_library(unstack_op DEPS stack_op)
298298
op_library(fake_quantize_op DEPS memory)
299+
op_library(fusion_lstm_op DEPS cpu_lstm_compute)
299300

300301
if (WITH_GPU)
301302
op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/fusion_lstm_op.h"
1616
#include <string>
1717
#include "paddle/fluid/operators/math/blas.h"
18+
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
1819
#include "paddle/fluid/operators/math/cpu_vec.h"
1920
#include "paddle/fluid/operators/math/fc_compute.h"
2021
#include "paddle/fluid/operators/math/sequence2batch.h"
@@ -269,7 +270,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
269270
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
270271
wh_data, D4, static_cast<T>(1), out, D4)
271272

272-
// gates: W_ch, W_ih, W_fh, W_oh
273273
#define GET_Ct(ct_1, gates, ct) \
274274
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
275275
act_cand(D, gates, gates); \
@@ -395,11 +395,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
395395
}
396396
}
397397
} else {
398+
// TODO(TJ): unly workaround, clean me
399+
std::function<void(const T*, const T*, T*, T*)> compute_ctht;
400+
if (platform::jit::MayIUse(platform::jit::avx) &&
401+
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
402+
act_cell_str == "tanh" && D == 8) {
403+
compute_ctht = math::lstm_compute_ctht<T>;
404+
} else {
405+
compute_ctht = [&](const T* gates, const T* ct_1, T* ct, T* ht) {
406+
COMPUTE_CtHt(gates, ct_1, ct, ht);
407+
}
408+
}
398409
for (int i = 0; i < N; ++i) {
399410
PROCESS_H0C0
400411
for (int step = tstart; step < seq_len; ++step) {
401412
GEMM_WH_ADDON(1, prev_h_data, xx_data);
402-
COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
413+
compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data);
403414
MOVE_ONE_STEP;
404415
}
405416
}

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ math_library(im2col)
4545
if (NOT WIN32) # windows do not support avx functions yet.
4646
math_library(gru_compute DEPS activation_functions math_function)
4747
math_library(lstm_compute DEPS activation_functions)
48+
# TODO(TJ): ugly workaround, clean me
49+
cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas)
4850
endif (NOT WIN32)
4951

5052
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
16+
#ifdef __AVX__
17+
#include <immintrin.h>
18+
#endif
19+
namespace paddle {
20+
namespace operators {
21+
namespace math {
22+
23+
#ifdef __AVX__
24+
// TODO(TJ): ugly workaround, clean me
25+
26+
namespace detail {
27+
namespace forward {
28+
namespace avx {} // namespace avx
29+
} // namespace forward
30+
} // namespace detail
31+
32+
template <>
33+
void lstm_compute_ctht<float>(const float* gates, const float* ct_1, float* ct,
34+
float* ht) {
35+
namespace act = detail::forward::avx;
36+
// gates: W_ch, W_ih, W_fh, W_oh
37+
__m256 c, i, f, o;
38+
c = _mm256_loadu_ps(gates);
39+
i = _mm256_loadu_ps(gates + 8);
40+
f = _mm256_loadu_ps(gates + 16);
41+
o = _mm256_loadu_ps(gates + 24);
42+
43+
/* C_t = C_t-1 * fgated + cand_gated * igated*/
44+
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
45+
i = _mm256_loadu_ps(ct_1);
46+
f = _mm256_mul_ps(i, act::Sigmoid(f));
47+
f = _mm256_add_ps(c, f);
48+
_mm256_storeu_ps(ct, f);
49+
50+
/* H_t = act_cell(C_t) * ogated */
51+
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
52+
_mm256_storeu_ps(ht, o);
53+
}
54+
#endif
55+
} // namespace math
56+
} // namespace operators
57+
} // namespace paddle
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <string>
17+
#include "paddle/fluid/operators/math/cpu_vec.h"
18+
#include "paddle/fluid/platform/cpu_info.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
namespace math {
23+
24+
// TODO(TJ): ugly workaround, clean me
25+
template <typename T>
26+
void lstm_compute_ctht(const T* gates, const T* ct_1, T* ct, T* ht) {
27+
// gates: W_ch, W_ih, W_fh, W_oh
28+
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
29+
vec_tanh<T, platform::jit::avx>(8, gates, gates);
30+
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
31+
for (int d = 0; d < 8; ++d) {
32+
// C_t = C_t-1 * fgated + cand_gated * igated
33+
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
34+
35+
// H_t = act_cell(C_t) * ogated
36+
T tmp = ct[d] * 2;
37+
tmp = static_cast<T>(0) - (tmp < static_cast<T>(SIGMOID_THRESHOLD_MIN))
38+
? min
39+
: ((tmp > static_cast<T>(SIGMOID_THRESHOLD_MAX))
40+
? static_cast<T>(SIGMOID_THRESHOLD_MAX)
41+
: tmp);
42+
vec_exp<T>(1, &tmp, &tmp);
43+
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
44+
ht[d] = tmp * o[d];
45+
}
46+
}
47+
48+
} // namespace math
49+
} // namespace operators
50+
} // namespace paddle

0 commit comments

Comments
 (0)