Skip to content

Commit d292ad8

Browse files
authored
Merge pull request #12958 from tensor-tang/refine/op/fusion_lstm
refine fusion lstm
2 parents 4fcc293 + c488ee9 commit d292ad8

File tree

2 files changed

+185
-33
lines changed

2 files changed

+185
-33
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/fusion_lstm_op.h"
1616
#include <string>
1717
#include "paddle/fluid/operators/math/blas.h"
18+
#include "paddle/fluid/operators/math/cpu_vec.h"
1819
#include "paddle/fluid/operators/math/detail/activation_functions.h"
1920
#include "paddle/fluid/operators/math/fc_compute.h"
2021
#include "paddle/fluid/operators/math/lstm_compute.h"
2122
#include "paddle/fluid/operators/math/sequence2batch.h"
23+
#include "paddle/fluid/platform/cpu_info.h"
24+
25+
DEFINE_bool(seq_mode, true, "Use sequence mode");
2226

2327
namespace paddle {
2428
namespace operators {
@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
98102
ctx->ShareLoD("X", "Hidden");
99103
ctx->ShareLoD("X", "Cell");
100104

101-
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
105+
int xx_width;
106+
if (FLAGS_seq_mode) {
107+
xx_width = wx_dims[1];
108+
} else {
109+
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
110+
}
102111
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
103112
ctx->ShareLoD("X", "XX");
104113
}
@@ -205,10 +214,138 @@ inline void ReorderInitState(const DeviceContext& ctx,
205214
row_shuffle(ctx, src, index_lod, dst, indexed_src);
206215
}
207216

208-
template <typename DeviceContext, typename T>
217+
template <typename T>
209218
class FuisonLSTMKernel : public framework::OpKernel<T> {
210219
public:
211-
void Compute(const framework::ExecutionContext& ctx) const override {
220+
// Sequence-mode forward pass: processes the sequences of X one after another,
// walking each sequence timestep by timestep directly on the packed rows
// (no reordering into batches).
//
// Inputs read from ctx : X (LoDTensor, total_T x M), WeightX, WeightH (D x 4D),
//                        Bias, and the optional initial states H0 / C0.
// Outputs written      : XX (total_T x 4D scratch holding the gate values),
//                        Hidden and Cell (one row of D values per timestep).
//
// Each XX row is laid out as four D-wide chunks:
//   [0, D)    candidate ("tilde")   -> candidate_activation
//   [D, 2D)   input gate            -> gate_activation
//   [2D, 3D)  forget gate           -> gate_activation
//   [3D, 4D)  output gate           -> gate_activation
// The forget-gate chunk is reused as scratch once the gate has been consumed.
//
// NOTE(review): this path never reads a peephole attribute and only passes
// Bias to the FC step, so peephole weights are not applied here.
void SeqCompute(const framework::ExecutionContext& ctx) const {
  using DeviceContext = paddle::platform::CPUDeviceContext;
  auto* x = ctx.Input<LoDTensor>("X");
  auto* h0 = ctx.Input<Tensor>("H0");
  auto* c0 = ctx.Input<Tensor>("C0");
  auto* wx = ctx.Input<Tensor>("WeightX");
  auto* wh = ctx.Input<Tensor>("WeightH");
  auto* bias = ctx.Input<Tensor>("Bias");

  auto* xx = ctx.Output<LoDTensor>("XX");
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
  auto* cell_out = ctx.Output<LoDTensor>("Cell");
  bool is_reverse = ctx.Attr<bool>("is_reverse");

  // Resolve the three activation functors once, picking the AVX variants when
  // the CPU reports support for them.
  std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
  auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
  auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
  if (platform::jit::MayIUse(platform::jit::avx)) {
    math::VecActivations<T, platform::jit::avx> act_functor;
    act_gate = act_functor(act_gate_str);
    act_cell = act_functor(act_cell_str);
    act_cand = act_functor(act_cand_str);
  } else {
    math::VecActivations<T, platform::jit::isa_any> act_functor;
    act_gate = act_functor(act_gate_str);
    act_cell = act_functor(act_cell_str);
    act_cand = act_functor(act_cand_str);
  }

  // Bind by const reference: correct whether lod() returns a reference or a
  // value, and avoids copying the offset table when it returns a reference.
  const auto& x_lod = x->lod();
  auto x_dims = x->dims();    // T x M
  auto wh_dims = wh->dims();  // D x 4D
  const int total_T = x_dims[0];
  const int N = x_lod[0].size() - 1;  // batch size (number of sequences)
  const int M = x_dims[1];            // x frame size
  const int D = wh_dims[0];
  const int D2 = D * 2;
  const int D3 = D * 3;
  const int D4 = wh_dims[1];

  const T* x_data = x->data<T>();
  // NOTE(review): c0_data is dereferenced whenever h0_data is set; this
  // assumes H0 and C0 are always provided together -- confirm it is enforced
  // before the kernel runs.
  const T* h0_data = h0 ? h0->data<T>() : nullptr;
  const T* c0_data = c0 ? c0->data<T>() : nullptr;
  const T* wx_data = wx->data<T>();
  const T* wh_data = wh->data<T>();
  T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
  T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
  T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());

  // Input projection for every timestep at once; presumably
  // XX = X * WeightX + Bias (FCCompute is defined elsewhere).
  auto blas = math::GetBlas<DeviceContext, T>(ctx);
  math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
                                    xx_data, bias->data<T>());

  // Per-timestep pointer strides. In reverse mode the cursors start on the
  // last row and walk backwards.
  int xx_offset = D4;
  int gate_offset = D;
  if (is_reverse) {
    const int offset = (total_T - 1) * D;
    xx_data = xx_data + offset * 4;
    hidden_out_data = hidden_out_data + offset;
    cell_out_data = cell_out_data + offset;
    xx_offset = -D4;
    gate_offset = -D;
  }

  // Advance all three row cursors by one timestep.
  auto move_step = [&]() {
    xx_data = xx_data + xx_offset;
    hidden_out_data = hidden_out_data + gate_offset;
    cell_out_data = cell_out_data + gate_offset;
  };

  for (int i = 0; i < N; ++i) {
    int bid = is_reverse ? N - 1 - i : i;
    int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
    if (seq_len == 0) {
      // An empty sequence owns no rows. Without this guard the
      // no-initial-state branch below would consume a row that belongs to the
      // next sequence and shift every subsequent read/write by one step.
      continue;
    }
    const T* prev_cell_data = nullptr;
    const T* prev_hidden_data = nullptr;
    int tstart = 0;
    if (h0_data) {
      prev_hidden_data = h0_data + bid * D;
      prev_cell_data = c0_data + bid * D;
    } else {
      // No initial state: the first step has no recurrent contribution and no
      // previous cell, so it is computed without the GEMM and forget term.
      // W_ch, W_ih, W_fh, W_oh
      act_gate(D3, xx_data + D, xx_data + D);
      act_cand(D, xx_data, xx_data);
      // cell out = input * tilde
      blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
      // hidden out = act_state(cell out) * outgate
      act_cell(D, cell_out_data, xx_data + D2);
      blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

      // prev
      prev_hidden_data = hidden_out_data;
      prev_cell_data = cell_out_data;
      tstart = 1;

      move_step();
    }
    for (int step = tstart; step < seq_len; ++step) {
      // Accumulate prev_hidden * WeightH into this step's gate
      // pre-activations (beta = 1 keeps the input projection already in XX).
      blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
                prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
                D4);

      // W_ch, W_ih, W_fh, W_oh
      act_gate(D3, xx_data + D, xx_data + D);
      act_cand(D, xx_data, xx_data);

      // a = forget * prev_cell
      blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);

      // b = input * tilde
      blas.VMUL(D, xx_data, xx_data + D, xx_data + D);

      // cell out = a + b
      blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);

      // hidden out = act_state(cell out) * outgate
      act_cell(D, cell_out_data, xx_data + D2);
      blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

      // prev
      prev_hidden_data = hidden_out_data;
      prev_cell_data = cell_out_data;

      move_step();
    }
  }
}
346+
347+
void BatchCompute(const framework::ExecutionContext& ctx) const {
348+
using DeviceContext = platform::CPUDeviceContext;
212349
auto* x = ctx.Input<LoDTensor>("X");
213350
auto* wx = ctx.Input<Tensor>("WeightX");
214351
auto* wh = ctx.Input<Tensor>("WeightH");
@@ -339,6 +476,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
339476
// restore the output cell state in LoDTensor from the batch cell
340477
to_seq(dev_ctx, batch_cell, cell_out);
341478
}
479+
// Kernel entry point: dispatches to the batch-reordering implementation when
// the seq_mode flag is off, and to the per-sequence implementation otherwise.
void Compute(const framework::ExecutionContext& ctx) const override {
  if (!FLAGS_seq_mode) {
    BatchCompute(ctx);
    return;
  }
  SeqCompute(ctx);
}
342486
};
343487

344488
} // namespace operators
@@ -348,7 +492,5 @@ namespace ops = paddle::operators;
348492
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
349493
paddle::framework::DefaultGradOpDescMaker<true>);
350494

351-
REGISTER_OP_CPU_KERNEL(
352-
fusion_lstm,
353-
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
354-
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
495+
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
496+
ops::FuisonLSTMKernel<double>);

python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ def fusion_lstm(
4343
act_cell, act_cand)
4444

4545

46-
class TestLstmOp(OpTest):
47-
def set_argument(self):
48-
self.lod = [[2, 3, 2]]
46+
class TestFusionLSTMOp(OpTest):
47+
def set_conf(self):
48+
pass
4949

5050
def setUp(self):
5151
self.op_type = 'fusion_lstm'
52-
self.lod = [[2, 3, 2]]
52+
self.lod = [[2, 3, 5, 4]]
5353
self.M = 8
5454
self.D = 16
5555
self.has_initial_state = False
@@ -58,33 +58,33 @@ def setUp(self):
5858
self.act_cell = 'tanh'
5959
self.act_cand = 'tanh'
6060
self.use_peepholes = False
61-
self.set_argument()
61+
self.set_conf()
6262

6363
T = sum(self.lod[0])
6464
bs = len(self.lod[0])
6565

66-
x = np.random.normal(size=(T, self.M)).astype('float64')
66+
x = np.random.normal(size=(T, self.M)).astype('float32')
6767
if self.has_initial_state:
68-
h0 = np.random.normal(size=(bs, self.D)).astype('float64')
69-
c0 = np.random.normal(size=(bs, self.D)).astype('float64')
68+
h0 = np.random.normal(size=(bs, self.D)).astype('float32')
69+
c0 = np.random.normal(size=(bs, self.D)).astype('float32')
7070
else:
71-
h0 = np.zeros((bs, self.D)).astype('float64')
72-
c0 = np.zeros((bs, self.D)).astype('float64')
71+
h0 = np.zeros((bs, self.D)).astype('float32')
72+
c0 = np.zeros((bs, self.D)).astype('float32')
7373

74-
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
74+
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
7575

7676
if self.use_peepholes:
77-
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
77+
b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
7878
else:
79-
b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
79+
b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
8080
w_b = np.copy(b[:, 0:4 * self.D])
8181
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
8282

8383
# this is the weight of fc
84-
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
84+
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
8585
# this is the bias of fc
8686
# and it should be manually added into the bias of this fusion LSTM
87-
bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
87+
bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
8888
b[0, 0:4 * self.D] += bx[0, :]
8989
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
9090
self.is_reverse, ACTIVATION[self.act_gate],
@@ -114,35 +114,45 @@ def setUp(self):
114114
}
115115

116116
def test_check_output(self):
117-
self.check_output(atol=1e-8)
117+
self.check_output()
118118

119119

120-
class TestLstmOpInitReverse(TestLstmOp):
121-
def set_argument(self):
120+
class TestFusionLSTMOpInit(TestFusionLSTMOp):
121+
def set_conf(self):
122+
self.has_initial_state = True
123+
124+
125+
class TestFusionLSTMOpReverse(TestFusionLSTMOp):
126+
def set_conf(self):
127+
self.is_reverse = True
128+
129+
130+
class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
131+
def set_conf(self):
122132
self.has_initial_state = True
123133
self.is_reverse = True
124134

125135

126-
class TestLstmOpMD1(TestLstmOp):
127-
def set_argument(self):
136+
class TestFusionLSTMOpMD1(TestFusionLSTMOp):
137+
def set_conf(self):
128138
self.M = 36
129139
self.D = 8
130140

131141

132-
class TestLstmOpMD2(TestLstmOp):
133-
def set_argument(self):
142+
class TestFusionLSTMOpMD2(TestFusionLSTMOp):
143+
def set_conf(self):
134144
self.M = 8
135145
self.D = 8
136146

137147

138-
class TestLstmOpMD3(TestLstmOp):
139-
def set_argument(self):
148+
class TestFusionLSTMOpMD3(TestFusionLSTMOp):
149+
def set_conf(self):
140150
self.M = 15
141151
self.D = 3
142152

143153

144-
class TestLstmOpBS1(TestLstmOp):
145-
def set_argument(self):
154+
class TestFusionLSTMOpBS1(TestFusionLSTMOp):
155+
def set_conf(self):
146156
self.lod = [[3]]
147157
self.D = 16
148158

0 commit comments

Comments
 (0)