@@ -519,50 +519,41 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
519
519
520
520
PDNode* patterns::FC (PDPattern* pattern, const std::string& name_scope,
521
521
PDNode* x, bool with_bias) {
522
- // Create Operators
523
- PDNode* elementwise_add_op{nullptr };
522
+ // mul op
524
523
auto * mul_op = pattern->NewNode (name_scope, " mul" )->assert_is_op (" mul" );
525
- if (with_bias) {
526
- elementwise_add_op = pattern->NewNode (name_scope, " elementwise_add" )
527
- ->assert_is_op (" elementwise_add" );
528
- }
529
- // Create variables
530
- // w
531
524
auto * mul_weight_var = pattern->NewNode (name_scope, " w" )
532
525
->AsInput ()
533
526
->assert_is_persistable_var ()
534
- ->assert_is_op_nth_input (" mul" , " Y" , 0 );
535
- PDNode* mul_out_var{nullptr };
527
+ ->assert_is_op_input (" mul" , " Y" );
528
+
529
+ PDNode* fc_out{nullptr };
536
530
if (with_bias) {
531
+ PDNode* elementwise_add_op{nullptr };
532
+ PDNode *mul_out_var{nullptr }, *bias{nullptr };
533
+ elementwise_add_op = pattern->NewNode (name_scope, " elementwise_add" )
534
+ ->assert_is_op (" elementwise_add" );
537
535
// intermediate variable, will be removed in the IR after fuse.
538
536
mul_out_var = pattern->NewNode (name_scope, " mul_out" )
539
537
->AsIntermediate ()
540
538
->assert_is_only_output_of_op (" mul" )
541
- ->assert_is_op_input (" elementwise_add" );
542
- }
543
- PDNode *bias{nullptr }, *fc_out{nullptr };
544
- if (with_bias) {
539
+ ->assert_is_op_input (" elementwise_add" , " X" );
545
540
// bias
546
541
bias = pattern->NewNode (name_scope, " fc_bias" )
547
- ->assert_is_op_input (" elementwise_add" )
548
- ->AsInput ();
542
+ ->AsInput ()
543
+ ->assert_is_persistable_var ()
544
+ ->assert_is_op_input (" elementwise_add" , " Y" );
549
545
// output
550
546
fc_out = pattern->NewNode (name_scope, " fc_out" )
551
547
->AsOutput ()
552
- ->assert_is_op_output (" elementwise_add" );
548
+ ->assert_is_op_output (" elementwise_add" , " Out" );
549
+ mul_op->LinksFrom ({x, mul_weight_var}).LinksTo ({mul_out_var});
550
+ elementwise_add_op->LinksFrom ({mul_out_var, bias}).LinksTo ({fc_out});
553
551
} else {
554
552
fc_out = pattern->NewNode (name_scope, " fc_out" )
555
553
->AsOutput ()
556
- ->assert_is_op_output (" mul" );
557
- }
558
-
559
- if (with_bias) {
560
- mul_op->LinksFrom ({mul_weight_var, x}).LinksTo ({mul_out_var});
561
- elementwise_add_op->LinksFrom ({mul_out_var, bias}).LinksTo ({fc_out});
562
- } else {
554
+ ->assert_is_op_output (" mul" , " Out" );
563
555
mul_op->LinksFrom ({mul_weight_var, x}).LinksTo ({fc_out});
564
556
}
565
-
566
557
return fc_out;
567
558
}
568
559
@@ -609,6 +600,10 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
609
600
NEW_NODE (gru, BatchResetHiddenPrev, output);
610
601
NEW_NODE (gru, BatchHidden, output);
611
602
603
+ BatchGate->AsIntermediate ();
604
+ BatchResetHiddenPrev->AsIntermediate ();
605
+ BatchHidden->AsIntermediate ();
606
+
612
607
gru_op->LinksFrom ({x, Weight, Bias});
613
608
gru_op->LinksTo ({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
614
609
return Hidden;
0 commit comments