Skip to content

Commit 6938e6c

Browse files
authored
Merge pull request #13603 from tensor-tang/refine/peephole
refine peephole
2 parents 9b03d53 + 209e9c3 commit 6938e6c

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,10 +77,12 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
7777
const std::string BatchedCellPreAct =
7878
patterns::UniqueKey("BatchedCellPreAct");
7979
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
80+
const std::string CheckedCell = patterns::UniqueKey("CheckedCell");
8081

8182
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
8283
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
8384
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
85+
scope->Var(CheckedCell)->GetMutable<framework::LoDTensor>();
8486

8587
op_desc.SetInput("H0", {});
8688
op_desc.SetInput("C0", {});
@@ -90,6 +92,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
9092
op_desc.SetOutput("BatchedGate", {BatchedGate});
9193
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
9294
op_desc.SetOutput("BatchedInput", {BatchedInput});
95+
op_desc.SetOutput("CheckedCell", {CheckedCell});
9396
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
9497
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
9598
// TODO(TJ): get from attr

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 27 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -76,12 +76,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
7676
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
7777
PADDLE_ENFORCE_EQ(b_dims[0], 1,
7878
"The first dimension of Input(Bias) should be 1.");
79-
PADDLE_ENFORCE_EQ(
80-
b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
81-
"The second dimension of Input(Bias) should be "
82-
"7 * %d if enable peepholes connection or"
83-
"4 * %d if disable peepholes",
84-
frame_size, frame_size);
79+
if (ctx->Attrs().Get<bool>("use_peepholes")) {
80+
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
81+
"The second dimension of Input(Bias) should be "
82+
"7 * %d if enable peepholes connection",
83+
frame_size);
84+
ctx->SetOutputDim("CheckedCell", {2, frame_size});
85+
} else {
86+
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
87+
"The second dimension of Input(Bias) should be "
88+
"4 * %d if disable peepholes",
89+
frame_size);
90+
}
8591

8692
framework::DDim out_dims({x_dims[0], frame_size});
8793
ctx->SetOutputDim("Hidden", out_dims);
@@ -173,6 +179,8 @@ void FusionLSTMOpMaker::Make() {
173179
AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
174180
AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
175181
AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
182+
AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.")
183+
.AsIntermediate();
176184
AddAttr<bool>("use_peepholes",
177185
"(bool, defalut: True) "
178186
"whether to enable diagonal/peephole connections.")
@@ -250,19 +258,19 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
250258
const int D3 = D * 3; \
251259
const int D4 = wh_dims[1];
252260

253-
#define INIT_BASE_INPUT_DATAS \
254-
const T* x_data = x->data<T>(); \
255-
const T* wx_data = wx->data<T>(); \
256-
const T* wh_data = wh->data<T>(); \
257-
/* diagonal weight*/ \
258-
const T* wc_data = bias->data<T>() + D4; \
259-
/* for peephole only*/ \
260-
Tensor checked_cell; \
261-
T* checked_cell_data = nullptr; \
262-
auto place = ctx.GetPlace(); \
263-
if (use_peepholes) { \
264-
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
265-
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
261+
// INIT_BASE_INPUT_DATAS: declares the raw data pointers shared by the kernel
// bodies that expand this macro.
//
// Names that must already be in scope at the expansion site (they are used
// here but not declared here): x, wx, wh, bias, ctx, use_peepholes, D4 and
// the element type T.
//
// Locals introduced by the expansion:
//   x_data, wx_data, wh_data - read-only data pointers of x, wx and wh.
//   wc_data           - bias data advanced by D4 elements (labelled
//                       "diagonal weight" below).
//   checked_cell_data - nullptr unless use_peepholes is true; in that case
//                       it is the mutable buffer of the "CheckedCell" output
//                       fetched from ctx and allocated on `place`.
//   place             - ctx.GetPlace().
//
// The peephole scratch buffer is obtained from the op's "CheckedCell" output
// rather than from a Tensor local to the macro.
#define INIT_BASE_INPUT_DATAS \
262+
const T* x_data = x->data<T>(); \
263+
const T* wx_data = wx->data<T>(); \
264+
const T* wh_data = wh->data<T>(); \
265+
/* diagonal weight*/ \
266+
const T* wc_data = bias->data<T>() + D4; \
267+
/* for peephole only*/ \
268+
T* checked_cell_data = nullptr; \
269+
auto place = ctx.GetPlace(); \
270+
if (use_peepholes) { \
271+
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
272+
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
273+
checked_cell_data = checked_cell->mutable_data<T>(place); \
266274
}
267275

268276
/// Compute LSTM

0 commit comments

Comments
 (0)