Skip to content

Commit c9bd2d5

Browse files
committed
refine fc and gru pattern
1 parent 7eebb90 commit c9bd2d5

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -519,50 +519,41 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
519519

520520
PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
521521
PDNode* x, bool with_bias) {
522-
// Create Operators
523-
PDNode* elementwise_add_op{nullptr};
522+
// mul op
524523
auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
525-
if (with_bias) {
526-
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
527-
->assert_is_op("elementwise_add");
528-
}
529-
// Create variables
530-
// w
531524
auto* mul_weight_var = pattern->NewNode(name_scope, "w")
532525
->AsInput()
533526
->assert_is_persistable_var()
534-
->assert_is_op_nth_input("mul", "Y", 0);
535-
PDNode* mul_out_var{nullptr};
527+
->assert_is_op_input("mul", "Y");
528+
529+
PDNode* fc_out{nullptr};
536530
if (with_bias) {
531+
PDNode* elementwise_add_op{nullptr};
532+
PDNode *mul_out_var{nullptr}, *bias{nullptr};
533+
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
534+
->assert_is_op("elementwise_add");
537535
// intermediate variable, will be removed in the IR after fuse.
538536
mul_out_var = pattern->NewNode(name_scope, "mul_out")
539537
->AsIntermediate()
540538
->assert_is_only_output_of_op("mul")
541-
->assert_is_op_input("elementwise_add");
542-
}
543-
PDNode *bias{nullptr}, *fc_out{nullptr};
544-
if (with_bias) {
539+
->assert_is_op_input("elementwise_add", "X");
545540
// bias
546541
bias = pattern->NewNode(name_scope, "fc_bias")
547-
->assert_is_op_input("elementwise_add")
548-
->AsInput();
542+
->AsInput()
543+
->assert_is_persistable_var()
544+
->assert_is_op_input("elementwise_add", "Y");
549545
// output
550546
fc_out = pattern->NewNode(name_scope, "fc_out")
551547
->AsOutput()
552-
->assert_is_op_output("elementwise_add");
548+
->assert_is_op_output("elementwise_add", "Out");
549+
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
550+
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
553551
} else {
554552
fc_out = pattern->NewNode(name_scope, "fc_out")
555553
->AsOutput()
556-
->assert_is_op_output("mul");
557-
}
558-
559-
if (with_bias) {
560-
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({mul_out_var});
561-
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
562-
} else {
554+
->assert_is_op_output("mul", "Out");
563555
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
564556
}
565-
566557
return fc_out;
567558
}
568559

@@ -609,6 +600,10 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
609600
NEW_NODE(gru, BatchResetHiddenPrev, output);
610601
NEW_NODE(gru, BatchHidden, output);
611602

603+
BatchGate->AsIntermediate();
604+
BatchResetHiddenPrev->AsIntermediate();
605+
BatchHidden->AsIntermediate();
606+
612607
gru_op->LinksFrom({x, Weight, Bias});
613608
gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
614609
return Hidden;

0 commit comments

Comments
 (0)