
Commit b50c33f

Use fixed activations in the LSTM kernel, since there is a bug in the activation function pointer. It will be fixed later.

1 parent bd680f1 commit b50c33f
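As context for the workaround: instead of dispatching through the activation function pointers, the kernel functors now hard-code tanh for the node/state activations and sigmoid for the gates. The sketch below is a minimal host-side illustration of the resulting forward cell math; identifiers mirror lstm_kernel.h, but it is not the actual kernel, and the clip thresholds are assumed values standing in for the SIGMOID_THRESHOLD_MIN/MAX macros.

// Illustration only (not part of this commit): forward LSTM cell with the
// hard-coded activations.
#include <algorithm>
#include <cmath>

template <typename T>
inline T sigmoid(T a) {
  const T lo = static_cast<T>(-40.0);  // assumed clip threshold
  const T hi = static_cast<T>(13.0);   // assumed clip threshold
  T tmp = std::min(std::max(a, lo), hi);
  return static_cast<T>(1.0) / (static_cast<T>(1.0) + std::exp(-tmp));
}

template <typename T>
void lstm_cell_forward(T& valueIn, T& valueIg, T& valueFg, T& valueOg,
                       T prevState, T& state, T& stateAtv, T& output,
                       T checkI, T checkF, T checkO) {
  valueIn = std::tanh(valueIn);                     // node activation, fixed to tanh
  valueIg = sigmoid(valueIg + prevState * checkI);  // input gate, fixed to sigmoid
  valueFg = sigmoid(valueFg + prevState * checkF);  // forget gate, fixed to sigmoid
  state = valueIn * valueIg + prevState * valueFg;
  valueOg = sigmoid(valueOg + state * checkO);      // output gate, fixed to sigmoid
  stateAtv = std::tanh(state);                      // state activation, fixed to tanh
  output = valueOg * stateAtv;
}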

5 files changed: +84 additions, -49 deletions


paddle/operators/lstm_op.cc

Lines changed: 14 additions & 0 deletions
@@ -82,6 +82,13 @@ class LSTMOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("Input", "Hidden");
     ctx->ShareLoD("Input", "Cell");
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(
+        ctx.Input<framework::LoDTensor>("Input")->type());
+  }
 };
 
 class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -239,6 +246,13 @@ class LSTMGradOp : public framework::OperatorWithKernel {
     if (ctx->HasOutput(b_g_name))
       ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(
+        ctx.Input<framework::LoDTensor>("Input")->type());
+  }
 };
 
 }  // namespace operators

paddle/operators/math/detail/lstm_cpu_kernel.h

Lines changed: 6 additions & 17 deletions
@@ -26,10 +26,7 @@ namespace detail {
 
 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                     int frameSize,
-                                     activation_mode_t active_node,
-                                     activation_mode_t active_gate,
-                                     activation_mode_t active_state) {
+                                     int frameSize) {
   T rValueIn;
   T rValueIg;
   T rValueFg;
@@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       rPrevState = value.prevStateValue[i];
     }
 
-    hppl::cpu::ForwardAct<T> act;
     op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-       rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
-       act(active_state));
+       rOut, rCheckI, rCheckF, rCheckO);
 
     valueIn[i] = rValueIn;
     valueIg[i] = rValueIg;
@@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 
 template <class T, class Op>
 void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
-                                      LstmMetaGrad<T> grad, int frameSize,
-                                      activation_mode_t active_node,
-                                      activation_mode_t active_gate,
-                                      activation_mode_t active_state) {
+                                      LstmMetaGrad<T> grad, int frameSize) {
   T rValueIn;
   T rValueIg;
   T rValueFg;
@@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       rPrevState = value.prevStateValue[i];
     }
 
-    hppl::cpu::BackwardAct<T> act;
     op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
        rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
        rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
-       rCheckOGrad, act(active_node), act(active_gate), act(active_state));
+       rCheckOGrad);
 
     gradIn[i] = rGradIn;
     gradIg[i] = rGradIg;
@@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
     avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
                                      active_gate, active_state);
   } else {
-    naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
-                                       active_gate, active_state);
+    naive_lstm_forward_one_sequence<T>(op, value, frameSize);
   }
 }
 
@@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
     avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
                                       active_gate, active_state);
   } else {
-    naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
-                                        active_gate, active_state);
+    naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize);
   }
 }
 
paddle/operators/math/detail/lstm_gpu_kernel.h

Lines changed: 8 additions & 20 deletions
@@ -32,9 +32,7 @@ namespace detail {
 */
 template <class T, class Op, bool isBatch>
 __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
-                              int batchSize, activation_mode_t active_node,
-                              activation_mode_t active_gate,
-                              activation_mode_t active_state) {
+                              int batchSize) {
   const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frameIdx >= frameSize) return;
 
@@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
     rPrevState = value.prevStateValue[frameIdx];
   }
 
-  hppl::gpu::ForwardAct<T> act;
   op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-     rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
-     act(active_state));
+     rOut, rCheckI, rCheckF, rCheckO);
 
   value.gateValue[frameIdx] = rValueIn;
   value.gateValue[frameIdx + frameSize] = rValueIg;
@@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
 template <class T, class Op, bool isBatch>
 __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                                LstmMetaGrad<T> grad, int frameSize,
-                               int batchSize, activation_mode_t active_node,
-                               activation_mode_t active_gate,
-                               activation_mode_t active_state) {
+                               int batchSize) {
   const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frameIdx >= frameSize) return;
 
@@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
     rPrevState = value.prevStateValue[frameIdx];
   }
 
-  hppl::gpu::BackwardAct<T> act;
   op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
      rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
-     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
-     act(active_node), act(active_gate), act(active_state));
+     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad);
 
   grad.gateGrad[frameIdx] = rGradIn;
   grad.gateGrad[frameIdx + frameSize] = rGradIg;
@@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
   if (batchSize == 1) {
     KeLstmForward<T, Op,
                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, frameSize, batchSize);
   } else {
     KeLstmForward<T, Op,
                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, frameSize, batchSize);
   }
 }
 
@@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
   if (batchSize == 1) {
     KeLstmBackward<T, Op,
                    /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, grad, frameSize, batchSize);
   } else {
     KeLstmBackward<T, Op,
                    /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, grad, frameSize, batchSize);
   }
 }
 
paddle/operators/math/detail/lstm_kernel.h

Lines changed: 51 additions & 8 deletions
@@ -24,22 +24,45 @@ namespace detail {
 
 namespace forward {
 
+template <typename T>
+DEVICE inline T sigmoid(const T a) {
+  const T min = SIGMOID_THRESHOLD_MIN;
+  const T max = SIGMOID_THRESHOLD_MAX;
+  T tmp = (a < min) ? min : ((a > max) ? max : a);
+  return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
+}
+
+template <typename T>
+DEVICE inline T tanh(const T a) {
+  T tmp = -2.0 * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
+  return (2.0 / (1.0 + exp(tmp))) - 1.0;
+}
+
 template <class T>
 class lstm {
  public:
   HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
                              T &prevState, T &state, T &stateAtv, T &output,
-                             T &checkI, T &checkF, T &checkO,
-                             typename hppl::ForwardActType<T>::type actInput,
-                             typename hppl::ForwardActType<T>::type actGate,
-                             typename hppl::ForwardActType<T>::type actState) {
+                             T &checkI, T &checkF, T &checkO) {
+#if 0
+    // TODO(qingqing) support to activation speficed by users
     valueIn = actInput(valueIn);
     valueIg = actGate(valueIg + prevState * checkI);
     valueFg = actGate(valueFg + prevState * checkF);
     state = valueIn * valueIg + prevState * valueFg;
     valueOg = actGate(valueOg + state * checkO);
     stateAtv = actState(state);
     output = valueOg * stateAtv;
+#else
+    valueIn = tanh<T>(valueIn);
+    valueIg = sigmoid<T>(valueIg + prevState * checkI);
+    valueFg = sigmoid<T>(valueFg + prevState * checkF);
+    state = valueIn * valueIg + prevState * valueFg;
+    valueOg = sigmoid<T>(valueOg + state * checkO);
+    stateAtv = tanh<T>(state);
+    output = valueOg * stateAtv;
+#endif
   }
 #ifndef __NVCC__
 #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@@ -72,6 +95,16 @@ class lstm {
 
 namespace backward {
 
+template <typename T>
+DEVICE inline T sigmoid(const T a, const T b) {
+  return a * b * (1.0 - b);
+}
+
+template <typename T>
+DEVICE inline T tanh(const T a, const T b) {
+  return a * (1.0 - b * b);
+}
+
 template <class T>
 class lstm {
  public:
@@ -80,10 +113,9 @@ class lstm {
                              T &prevState, T &prevStateGrad, T &state,
                              T &stateGrad, T &stateAtv, T &outputGrad,
                              T &checkI, T &checkF, T &checkO, T &checkIGrad,
-                             T &checkFGrad, T &checkOGrad,
-                             typename hppl::BackwardActType<T>::type actInput,
-                             typename hppl::BackwardActType<T>::type actGate,
-                             typename hppl::BackwardActType<T>::type actState) {
+                             T &checkFGrad, T &checkOGrad) {
+#if 0
+    // TODO(qingqing) support to activation speficed by users
     gradOg = actGate(outputGrad * stateAtv, valueOg);
     stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
     gradIn = actInput(stateGrad * valueIg, valueIn);
@@ -93,6 +125,17 @@ class lstm {
     checkIGrad = gradIg * prevState;
     checkFGrad = gradFg * prevState;
     checkOGrad = gradOg * state;
+#else
+    gradOg = sigmoid<T>(outputGrad * stateAtv, valueOg);
+    stateGrad += tanh<T>(outputGrad * valueOg, stateAtv) + gradOg * checkO;
+    gradIn = tanh<T>(stateGrad * valueIg, valueIn);
+    gradIg = sigmoid<T>(stateGrad * valueIn, valueIg);
+    gradFg = sigmoid<T>(stateGrad * prevState, valueFg);
+    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
+    checkIGrad = gradIg * prevState;
+    checkFGrad = gradFg * prevState;
+    checkOGrad = gradOg * state;
+#endif
   }
 #ifndef __NVCC__
 #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
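A note on the two-argument helpers added in the backward namespace above: they fold the incoming gradient a into the derivative of the activation evaluated at its saved forward output b, i.e. sigmoid'(x) = y * (1 - y) and tanh'(x) = 1 - y^2 with y the forward output. The snippet below is a small standalone finite-difference check of those identities; it is illustrative only and not part of the commit.

// Illustration only: check the derivative identities used by the backward helpers.
#include <cmath>
#include <cstdio>

double dsigmoid(double grad, double y) { return grad * y * (1.0 - y); }  // y = sigmoid(x)
double dtanh(double grad, double y) { return grad * (1.0 - y * y); }     // y = tanh(x)

int main() {
  const double x = 0.3;
  const double eps = 1e-6;

  const double ys = 1.0 / (1.0 + std::exp(-x));
  const double fd_sigmoid = (1.0 / (1.0 + std::exp(-(x + eps))) - ys) / eps;
  std::printf("sigmoid': analytic %.6f, finite-diff %.6f\n", dsigmoid(1.0, ys), fd_sigmoid);

  const double yt = std::tanh(x);
  const double fd_tanh = (std::tanh(x + eps) - yt) / eps;
  std::printf("tanh':    analytic %.6f, finite-diff %.6f\n", dtanh(1.0, yt), fd_tanh);
  return 0;
}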

python/paddle/v2/framework/tests/test_lstm_op.py

Lines changed: 5 additions & 4 deletions
@@ -110,7 +110,7 @@ def _reverse(x, lod):
 
 class TestLstmOp(OpTest):
     def set_argument(self):
-        self.lod = [[0, 2, 6]]
+        self.lod = [[0, 2, 5, 7]]
         self.D = 16
 
         self.act_gate = 'sigmoid'
@@ -164,12 +164,13 @@ def test_check_grad(self):
         # TODO(qingqing) remove folowing two lines after the check_grad is refined.
         self.outputs['BatchGate'] = None
         self.outputs['BatchCellPreAct'] = None
-        self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
+        self.check_grad(
+            ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02)
 
 
 class TestLstmOpHasNoInitial(TestLstmOp):
     def set_argument(self):
-        self.lod = [[0, 2, 6]]
+        self.lod = [[0, 2, 5, 7]]
         self.D = 16
 
         self.act_gate = 'sigmoid'
@@ -182,7 +183,7 @@ def set_argument(self):
 
 class TestLstmOpRerverse(TestLstmOp):
     def set_argument(self):
-        self.lod = [[0, 2, 6]]
+        self.lod = [[0, 2, 5, 7]]
         self.D = 16
 
         self.act_gate = 'sigmoid'
