
Commit 34aac18

Merge pull request #3 from reyoung/pr/4929
Several Enhancements
2 parents 694bc64 + 65906ef commit 34aac18


9 files changed: +102 -97 lines changed

paddle/operators/lstm_op.cc

Lines changed: 8 additions & 8 deletions
@@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel {
     } else {
       PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
                         "The second dimension of Input(Bias) should be "
-                        "4 * %d if diable peepholes connection",
+                        "4 * %d if disable peepholes connection",
                         frame_size);
     }
     ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
@@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Input",
              "(LoDTensor) the first input is a LodTensor, which support "
              "variable-time length input sequence. The underlying tensor in "
-             "this LoDTenosr is a matrix with shape (T X 4D), where, T is the "
+             "this LoDTensor is a matrix with shape (T X 4D), where, T is the "
              "total time steps in this mini-batch, D is the hidden size.");
     AddInput("H0",
              "(Tensor, optional) the initial hidden state is an optional "
@@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
              " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
     AddOutput("BatchGate",
               "(LoDTensor) This LoDTensor contains input gate, forget gate "
-              "and output gate aftern the nonlinear computation. This "
+              "and output gate after the nonlinear computation. This "
               "LoDTensor has the same shape with the reorganized input, which "
               "was also be called batch input. The LoD size is 2. The first "
               "LoD is the batch offsets and the second LoD contains the "
@@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(false);
     AddAttr<std::string>(
         "gateActivation",
-        "(string, defalut: sigmoid)"
+        "(string, default: sigmoid)"
         "The activation for input gate, forget gate and output "
-        "gate, `sigmoid` by defalut.")
+        "gate, `sigmoid` by default.")
         .SetDefault("sigmoid");
     AddAttr<std::string>("cellActivation",
-                         "(string, defalut: tanh)"
+                         "(string, default: tanh)"
                          "The activation for cell output, `tanh` by defalut.")
         .SetDefault("tanh");
     AddAttr<std::string>("candidateActivation",
-                         "(string, defalut: tanh)"
+                         "(string, default: tanh)"
                          "The activation for candidate hidden state, "
-                         "`tanh` by defalut.")
+                         "`tanh` by default.")
         .SetDefault("tanh");
     AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator

paddle/operators/lstm_op.h

Lines changed: 9 additions & 9 deletions
@@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
 
     auto in_dims = input->dims();
-    int frame_size = in_dims[1] / 4;
+    int frame_size = static_cast<int>(in_dims[1] / 4);
     framework::DDim dims({in_dims[0], frame_size});
 
     if (bias) {
@@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel<T> {
 
     math::LstmMetaValue<T> lstm_value;
     T* bias_data = const_cast<T*>(bias->data<T>());
-    // the code styple in LstmMetaValue will be updated later.
+    // the code style in LstmMetaValue will be updated later.
     lstm_value.checkIg = bias_data + 4 * frame_size;
     lstm_value.checkFg = lstm_value.checkIg + frame_size;
     lstm_value.checkOg = lstm_value.checkFg + frame_size;
@@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel<T> {
     framework::LoDTensor batch_cell_pre_act;
     batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
 
-    auto batch_lod = batch_gate->lod()[0];
-    int num_batch = batch_lod.size() - 1;
+    auto& batch_starts = batch_gate->lod()[0];
+    size_t num_batch = batch_starts.size() - 1;
     auto gate_act = ctx.Attr<std::string>("gateActivation");
     auto cell_act = ctx.Attr<std::string>("cellActivation");
     auto cand_act = ctx.Attr<std::string>("candidateActivation");
 
-    for (int n = 0; n < num_batch; n++) {
-      int bstart = batch_lod[n];
-      int bend = batch_lod[n + 1];
+    for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
 
       Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
       Tensor out_t = batch_out.Slice<T>(bstart, bend);
@@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> {
       int cur_batch_size = bend - bstart;
 
       if (n != 0) {
-        int pre_h_start = batch_lod[n - 1];
+        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_hidden_t = batch_out.Slice<T>(pre_h_start, pre_h_end);
         math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
                                *weight, false, static_cast<T>(1.0), &gate_t,
                                static_cast<T>(1.0));
       }
-      // else if : support the initial hidden and cell
+      // else if : FIXME support the initial hidden and cell
 
       lstm_value.gateValue = gate_t.data<T>();
       lstm_value.outputValue = out_t.data<T>();
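
Note: the loop above consumes the level-0 LoD of the batched gate tensor, whose entries are size_t offsets, while the Slice calls here take int, hence the static_casts. A minimal, self-contained sketch of that access pattern (with hypothetical offsets, not values from this commit):

    #include <cstddef>
    #include <vector>

    int main() {
      // Hypothetical level-0 LoD offsets: three time-step batches of rows.
      std::vector<std::size_t> batch_starts = {0, 3, 5, 6};
      std::size_t num_batch = batch_starts.size() - 1;
      for (std::size_t n = 0; n < num_batch; n++) {
        // Offsets are size_t; Slice-style APIs take int, hence the casts.
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        // Rows [bstart, bend) of the batched input belong to time step n.
        (void)bstart;
        (void)bend;
      }
      return 0;
    }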

paddle/operators/math/detail/lstm_kernel.h

Lines changed: 42 additions & 41 deletions
@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/math/detail/hl_activation_functions.h"
+#include "paddle/platform/hostdevice.h"
 
-#ifdef __CUDA_ARCH__
-#define INLINE __device__ inline
-#else
-#define INLINE inline
-#endif
+#include <type_traits>
 
 namespace paddle {
 namespace operators {
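
Note: the removed per-file INLINE macro is replaced by HOSTDEVICE from "paddle/platform/hostdevice.h", which centralizes the host/device qualifier. A sketch of what such a macro typically looks like (an assumption about that header, not its verified contents):

    // Hypothetical shared qualifier macro, shown for illustration only.
    #ifdef __CUDACC__
    #define HOSTDEVICE __host__ __device__
    #else
    #define HOSTDEVICE
    #endif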
@@ -30,12 +27,12 @@ namespace forward {
 template <class T>
 class lstm {
  public:
-  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                         T &prevState, T &state, T &stateAtv, T &output,
-                         T &checkI, T &checkF, T &checkO,
-                         typename hppl::ForwardActType<T>::type actInput,
-                         typename hppl::ForwardActType<T>::type actGate,
-                         typename hppl::ForwardActType<T>::type actState) {
+  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                             T &prevState, T &state, T &stateAtv, T &output,
+                             T &checkI, T &checkF, T &checkO,
+                             typename hppl::ForwardActType<T>::type actInput,
+                             typename hppl::ForwardActType<T>::type actGate,
+                             typename hppl::ForwardActType<T>::type actState) {
     valueIn = actInput(valueIn);
     valueIg = actGate(valueIg + prevState * checkI);
     valueFg = actGate(valueFg + prevState * checkF);
@@ -45,17 +42,19 @@ class lstm {
     output = valueOg * stateAtv;
   }
 #ifndef __NVCC__
-#ifndef __AVX__
+#ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
   static const bool avx = false;
 #else
-  static const bool avx = true;
-  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
-                         __m256 &valueOg, __m256 &prevState, __m256 &state,
-                         __m256 &stateAtv, __m256 &output, __m256 &checkI,
-                         __m256 &checkF, __m256 &checkO,
-                         hppl::Active<__m256>::forward actInput,
-                         hppl::Active<__m256>::forward actGate,
-                         hppl::Active<__m256>::forward actState) {
+  // Only float support AVX optimization
+  static const bool avx = std::is_same<T, float>::value;
+
+  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                             __m256 &valueOg, __m256 &prevState, __m256 &state,
+                             __m256 &stateAtv, __m256 &output, __m256 &checkI,
+                             __m256 &checkF, __m256 &checkO,
+                             hppl::Active<__m256>::forward actInput,
+                             hppl::Active<__m256>::forward actGate,
+                             hppl::Active<__m256>::forward actState) {
     valueIn = actInput(valueIn);
     valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
     valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
@@ -76,14 +75,15 @@ namespace backward {
 template <class T>
 class lstm {
  public:
-  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                         T &gradIn, T &gradIg, T &gradFg, T &gradOg,
-                         T &prevState, T &prevStateGrad, T &state, T &stateGrad,
-                         T &stateAtv, T &outputGrad, T &checkI, T &checkF,
-                         T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad,
-                         typename hppl::BackwardActType<T>::type actInput,
-                         typename hppl::BackwardActType<T>::type actGate,
-                         typename hppl::BackwardActType<T>::type actState) {
+  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                             T &gradIn, T &gradIg, T &gradFg, T &gradOg,
+                             T &prevState, T &prevStateGrad, T &state,
+                             T &stateGrad, T &stateAtv, T &outputGrad,
+                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
+                             T &checkFGrad, T &checkOGrad,
+                             typename hppl::BackwardActType<T>::type actInput,
+                             typename hppl::BackwardActType<T>::type actGate,
+                             typename hppl::BackwardActType<T>::type actState) {
     gradOg = actGate(outputGrad * stateAtv, valueOg);
     stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
     gradIn = actInput(stateGrad * valueIg, valueIn);
@@ -95,21 +95,22 @@ class lstm {
     checkOGrad = gradOg * state;
   }
 #ifndef __NVCC__
-#ifndef __AVX__
+#ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
   static const bool avx = false;
 #else
-  static const bool avx = true;
-  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
-                         __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
-                         __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
-                         __m256 &prevStateGrad, __m256 &state,
-                         __m256 &stateGrad, __m256 &stateAtv,
-                         __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
-                         __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
-                         __m256 &checkOGrad,
-                         hppl::Active<__m256>::backward actInput,
-                         hppl::Active<__m256>::backward actGate,
-                         hppl::Active<__m256>::backward actState) {
+  // Only float support AVX optimization
+  static const bool avx = std::is_same<T, float>::value;
+  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                             __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
+                             __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
+                             __m256 &prevStateGrad, __m256 &state,
+                             __m256 &stateGrad, __m256 &stateAtv,
+                             __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
+                             __m256 &checkO, __m256 &checkIGrad,
+                             __m256 &checkFGrad, __m256 &checkOGrad,
+                             hppl::Active<__m256>::backward actInput,
+                             hppl::Active<__m256>::backward actGate,
+                             hppl::Active<__m256>::backward actState) {
     gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
     stateGrad = _mm256_add_ps(
         actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
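
Note: the new <type_traits> include supports the change to the avx flag: instead of being unconditionally true whenever __AVX__ is defined, it is now true only for the float instantiation, since the __m256 overloads operate on packed single-precision values. A standalone illustration of the trait (not Paddle code):

    #include <type_traits>

    // std::is_same<T, float>::value is a compile-time constant per instantiation.
    static_assert(std::is_same<float, float>::value,
                  "float instantiations enable the AVX path");
    static_assert(!std::is_same<double, float>::value,
                  "double instantiations keep avx == false");

    int main() { return 0; }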

paddle/operators/math/lstm_compute.cc

Lines changed: 5 additions & 4 deletions
@@ -24,8 +24,8 @@ template <class T>
 struct LstmUnitFunctor<platform::CPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      std::string gate_act, std::string cell_act,
-                      std::string cand_act) {
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
                                ActiveType(cand_act), ActiveType(gate_act),
@@ -45,8 +45,9 @@ template <class T>
 struct LstmUnitGradFunctor<platform::CPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                      int frame_size, int batch_size, std::string gate_act,
-                      std::string cell_act, std::string cand_act) {
+                      int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     for (int b = 0; b < batch_size; b++) {
       detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
                                 frame_size, ActiveType(cand_act),

paddle/operators/math/lstm_compute.cu

Lines changed: 5 additions & 4 deletions
@@ -24,8 +24,8 @@ template <class T>
 struct LstmUnitFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      std::string gate_act, std::string cell_act,
-                      std::string cand_act) {
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
                                 frame_size, batch_size, ActiveType(cand_act),
                                 ActiveType(gate_act), ActiveType(cell_act));
@@ -36,8 +36,9 @@ template <class T>
 struct LstmUnitGradFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext& context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                      int frame_size, int batch_size, std::string gate_act,
-                      std::string cell_act, std::string cand_act) {
+                      int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
     detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
                               frame_size, batch_size, ActiveType(cand_act),
                               ActiveType(gate_act), ActiveType(cell_act));

paddle/operators/math/lstm_compute.h

Lines changed: 5 additions & 4 deletions
@@ -72,17 +72,18 @@ class LstmUnitFunctor {
  public:
   static void compute(const platform::DeviceContext &context,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      std::string gate_act, std::string cell_act,
-                      std::string cand_act);
+                      const std::string &gate_act, const std::string &cell_act,
+                      const std::string &cand_act);
 };
 
 template <typename Place, typename T>
 class LstmUnitGradFunctor {
  public:
   static void compute(const platform::DeviceContext &context,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                      int frame_size, int batch_size, std::string gate_act,
-                      std::string cell_act, std::string cand_act);
+                      int frame_size, int batch_size,
+                      const std::string &gate_act, const std::string &cell_act,
+                      const std::string &cand_act);
 };
 
 }  // namespace math
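
Note: across lstm_compute.{h,cc,cu} the activation names are now passed by const reference instead of by value, so each compute() call no longer copies three std::string arguments. A minimal illustration of the difference (generic C++, not Paddle code):

    #include <string>

    void by_value(std::string s) { (void)s; }             // copies the argument
    void by_const_ref(const std::string& s) { (void)s; }  // binds without copying

    int main() {
      std::string gate_act = "sigmoid";
      by_value(gate_act);      // one std::string copy per call
      by_const_ref(gate_act);  // no copy of the named string
      by_const_ref("tanh");    // a temporary std::string binds to the const reference
      return 0;
    }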

paddle/operators/math/sequence2batch.cc

Lines changed: 0 additions & 2 deletions
@@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
 template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
 template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
 
-template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
-template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
 template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
 template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;

paddle/operators/math/sequence2batch.cu

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ namespace math {
 template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
 __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
                                      int64_t height, int64_t width,
-                                     const bool is_src_index) {
+                                     bool is_src_index) {
   int idx = threadIdx.x;
   int idy = threadIdx.y;
   int id = blockIdx.x + idy * GridDimX;
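
Note: dropping the top-level const from the by-value kernel parameter is a style-only change: const on a value parameter is not part of the function type, so the declaration, the callers, and the generated code are unaffected. For example (generic C++, not Paddle code):

    void copy_rows(bool is_src_index);         // declaration
    void copy_rows(const bool is_src_index) {  // definition of the same function
      (void)is_src_index;  // const only forbids modification inside the body
    }

    int main() {
      copy_rows(true);
      return 0;
    }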
