Skip to content

Commit e61cf32

Browse files
committed
complete reverse seq
1 parent 1777cd0 commit e61cf32

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 27 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -229,6 +229,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
229229
auto* xx = ctx.Output<LoDTensor>("XX");
230230
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
231231
auto* cell_out = ctx.Output<LoDTensor>("Cell");
232+
bool is_reverse = ctx.Attr<bool>("is_reverse");
232233

233234
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
234235
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
@@ -247,8 +248,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
247248
}
248249

249250
auto x_lod = x->lod();
250-
auto x_dims = x->dims(); // T x M
251-
auto wh_dims = wh->dims(); // D x 4D
251+
auto x_dims = x->dims(); // T x M
252+
auto wh_dims = wh->dims(); // D x 4D
253+
const int total_T = x_dims[0];
252254
const int N = x_lod[0].size() - 1; // batch size
253255
const int M = x_dims[1]; // x frame size
254256
const int D = wh_dims[0];
@@ -266,17 +268,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
266268
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
267269

268270
auto blas = math::GetBlas<DeviceContext, T>(ctx);
269-
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
271+
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
270272
xx_data, bias->data<T>());
273+
int xx_offset = D4;
274+
int gate_offset = D;
275+
if (is_reverse) {
276+
const int offset = (total_T - 1) * D;
277+
xx_data = xx_data + offset * 4;
278+
hidden_out_data = hidden_out_data + offset;
279+
cell_out_data = cell_out_data + offset;
280+
xx_offset = -D4;
281+
gate_offset = -D;
282+
}
283+
284+
auto move_step = [&]() {
285+
xx_data = xx_data + xx_offset;
286+
hidden_out_data = hidden_out_data + gate_offset;
287+
cell_out_data = cell_out_data + gate_offset;
288+
};
271289

272290
for (int i = 0; i < N; ++i) {
273-
int seq_len = x_lod[0][i + 1] - x_lod[0][i];
291+
int bid = is_reverse ? N - 1 - i : i;
292+
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
274293
const T* prev_cell_data = NULL;
275294
const T* prev_hidden_data = NULL;
276295
int tstart = 0;
277296
if (h0_data) {
278-
prev_hidden_data = h0_data + i * D;
279-
prev_cell_data = c0_data + i * D;
297+
prev_hidden_data = h0_data + bid * D;
298+
prev_cell_data = c0_data + bid * D;
280299
} else {
281300
// W_ch, W_ih, W_fh, W_oh
282301
act_gate(D3, xx_data + D, xx_data + D);
@@ -292,10 +311,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
292311
prev_cell_data = cell_out_data;
293312
tstart = 1;
294313

295-
// move offset
296-
xx_data = xx_data + D4;
297-
hidden_out_data = hidden_out_data + D;
298-
cell_out_data = cell_out_data + D;
314+
move_step();
299315
}
300316
for (int step = tstart; step < seq_len; ++step) {
301317
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
@@ -323,10 +339,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
323339
prev_hidden_data = hidden_out_data;
324340
prev_cell_data = cell_out_data;
325341

326-
// move offset
327-
xx_data = xx_data + D4;
328-
hidden_out_data = hidden_out_data + D;
329-
cell_out_data = cell_out_data + D;
342+
move_step();
330343
}
331344
}
332345
}

python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -122,14 +122,15 @@ def set_conf(self):
122122
self.has_initial_state = True
123123

124124

125-
# class TestFusionLSTMOpReverse(TestFusionLSTMOp):
126-
# def set_conf(self):
127-
# self.is_reverse = True
128-
129-
# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
130-
# def set_conf(self):
131-
# self.has_initial_state = True
132-
# self.is_reverse = True
125+
class TestFusionLSTMOpReverse(TestFusionLSTMOp):
126+
def set_conf(self):
127+
self.is_reverse = True
128+
129+
130+
class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
131+
def set_conf(self):
132+
self.has_initial_state = True
133+
self.is_reverse = True
133134

134135

135136
class TestFusionLSTMOpMD1(TestFusionLSTMOp):

0 commit comments

Comments
 (0)