Skip to content

Commit 6efdea8

Browse files
committed
1. add shuffle_channel_detect
1 parent 8121b3e commit 6efdea8

File tree

4 files changed

+51
-0
lines changed

4 files changed

+51
-0
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base)
7070
pass_library(runtime_context_cache_pass base)
7171
pass_library(quant_conv2d_dequant_fuse_pass inference)
7272
pass_library(fillconstant_elementwisemul_fuse inference)
73+
pass_library(shuffle_channel_detect_pass inference)
7374

7475
if(ANAKIN_FOUND)
7576
pass_library(simplify_anakin_priorbox_detection_out_pass inference)

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
17061706
}
17071707
}
17081708

1709+
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
1710+
auto reshape1_op =
1711+
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
1712+
1713+
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
1714+
->assert_is_op_output("reshape2", "Out")
1715+
->assert_is_op_input("transpose2")
1716+
->AsIntermediate();
1717+
1718+
auto transpose_op =
1719+
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
1720+
1721+
auto transpose_out = pattern->NewNode(transpose_out_repr())
1722+
->assert_is_op_output("transpose2", "Out")
1723+
->assert_is_op_input("reshape2")
1724+
->AsIntermediate();
1725+
1726+
auto reshape2_op =
1727+
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
1728+
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
1729+
->assert_is_op_output("reshape2", "Out")
1730+
->AsOutput();
1731+
1732+
reshape1_op->LinksFrom({reshape1_in});
1733+
reshape1_out->LinksFrom({reshape1_op});
1734+
transpose_op->LinksFrom({reshape1_out});
1735+
transpose_out->LinksFrom({transpose_op});
1736+
reshape2_op->LinksFrom({transpose_out});
1737+
reshape2_out->LinksFrom({reshape2_op});
1738+
}
1739+
17091740
} // namespace ir
17101741
} // namespace framework
17111742
} // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
892892
}
893893
};
894894

895+
struct ShuffleChannelPattern : public PatternBase {
896+
ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
897+
: PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
898+
899+
void operator()(PDNode* reshape1_in);
900+
901+
PATTERN_DECL_NODE(reshape1_op);
902+
PATTERN_DECL_NODE(reshape1_out);
903+
904+
PATTERN_DECL_NODE(transpose_op);
905+
PATTERN_DECL_NODE(transpose_out);
906+
PATTERN_DECL_NODE(reshape2_op);
907+
PATTERN_DECL_NODE(reshape2_out);
908+
};
909+
895910
} // namespace patterns
896911

897912
// Link two ir::Nodes from each other.

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ const std::vector<std::string> kAnakinSubgraphPasses({
7979
"fc_fuse_pass", //
8080
"conv_elementwise_add_fuse_pass", //
8181
"fc_gru_fuse_pass", //
82+
"graph_viz_pass", //
83+
"shuffle_channel_detect_pass", //
84+
"graph_viz_pass", //
8285
"anakin_subgraph_pass", //
86+
"graph_viz_pass", //
8387
"fc_gru_fuse_pass", //
8488
});
8589

0 commit comments

Comments
 (0)