Commit 9f2ccf5

Merge pull request #13237 from tensor-tang/refine/op/peephole

refine fusion lstm/peephole and fusion gru

2 parents: 225ecee + 718033e

4 files changed: +254 additions, -329 deletions


paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
 #include <string>
 #include "paddle/fluid/framework/lod_tensor.h"

paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 8 additions & 10 deletions
@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Input(WeightX) of GRU should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("WeightH"),
                  "Input(WeightH) of GRU should not be null.");
-
   PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
-                 "Output(ReorderedH0) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
-                 "Output(BatchedInput) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
-                 "Output(BatchedOut) of GRU should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                  "Output(Hidden) of GRU should not be null.");

@@ -80,15 +73,20 @@
   }
   framework::DDim out_dims({x_dims[0], frame_size});
   ctx->SetOutputDim("Hidden", out_dims);
-  ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
-  ctx->SetOutputDim("BatchedOut", out_dims);
   ctx->ShareLoD("X", "Hidden");
-
   int xx_width;
   if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
+                   "Output(ReorderedH0) of GRU should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
+                   "Output(BatchedInput) of GRU should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
+                   "Output(BatchedOut) of GRU should not be null.");
+    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
+    ctx->SetOutputDim("BatchedOut", out_dims);
   }
   ctx->SetOutputDim("XX", {x_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
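
Net effect of the two fusion_gru_op.cc hunks: the ReorderedH0, BatchedInput, and BatchedOut output checks and SetOutputDim calls move from the unconditional prologue into the non-use_seq branch, since the sequence-mode path never produces those intermediates. Below is a minimal standalone sketch of that pattern, not PaddlePaddle source; ShapeCtx, InferShapeSketch, and declared_outputs are hypothetical stand-ins for the real InferShapeContext API.

// Standalone sketch (not PaddlePaddle source) of the pattern this commit
// applies: optional outputs are checked only on the path that produces them.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an InferShape-style context.
struct ShapeCtx {
  bool use_seq = false;
  std::vector<std::string> declared_outputs;

  bool HasOutput(const std::string& name) const {
    for (const auto& o : declared_outputs) {
      if (o == name) return true;
    }
    return false;
  }
};

void InferShapeSketch(const ShapeCtx& ctx) {
  // Hidden is required on every path, as in the real op.
  assert(ctx.HasOutput("Hidden") && "Output(Hidden) should not be null.");
  if (!ctx.use_seq) {
    // Batched intermediates are only needed by the batch-mode kernel,
    // mirroring the checks moved into the else branch above.
    assert(ctx.HasOutput("BatchedInput") &&
           "Output(BatchedInput) should not be null.");
    assert(ctx.HasOutput("BatchedOut") &&
           "Output(BatchedOut) should not be null.");
  }
  std::cout << "shape inference passed (use_seq=" << ctx.use_seq << ")\n";
}

int main() {
  // Sequence mode: batched intermediates may be omitted.
  InferShapeSketch({true, {"Hidden"}});
  // Batch mode: they must be declared.
  InferShapeSketch({false, {"Hidden", "BatchedInput", "BatchedOut"}});
  return 0;
}

In the real operator the same branch also sets the BatchedInput and BatchedOut dims, so sequence-mode graphs no longer need to declare tensors they will never fill.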
