Skip to content

Commit 40f8456

Browse files
committed
refine fuse pattern and attr
test=develop
1 parent cbbacb2 commit 40f8456

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,6 @@ PDNode *PDNode::assert_is_op() {
349349
return this;
350350
}
351351

352-
// PDNode *PDNode::assert_op_attr() {
353-
// asserts_.emplace_back([](Node *x) { return x && x->IsOp(); });
354-
// return this;
355-
// }
356-
357352
PDNode *PDNode::assert_is_op(const std::string &op_type) {
358353
asserts_.emplace_back([op_type](Node *x) {
359354
return x && x->IsOp() && x->Op()->Type() == op_type;
@@ -770,10 +765,10 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
770765
paddle::framework::ir::PDNode *seqconv_input) {
771766
// Create Operators
772767
seqconv_input->assert_is_op_input("sequence_conv", "X");
773-
auto *seqconv_op =
774-
pattern->NewNode(seqconv_repr())->assert_is_op("sequence_conv");
775-
// ->assert_op_attr("paddingTrainable", false)
776-
// ->assert_op_attr("contextStride", 1)
768+
auto *seqconv_op = pattern->NewNode(seqconv_repr())
769+
->assert_is_op("sequence_conv")
770+
->assert_op_attr<bool>("paddingTrainable", false)
771+
->assert_op_attr<int>("contextStride", 1);
777772

778773
auto *eltadd_op =
779774
pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add");

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ struct PDNode {
128128
const std::unordered_set<std::string>& op_types,
129129
const std::string& argument, int nth);
130130

131+
template <typename T>
132+
PDNode* assert_op_attr(const std::string& attr_name, const T& attr) {
133+
asserts_.emplace_back([=](Node* x) {
134+
return x && x->IsOp() && x->Op()->HasAttr(attr_name) &&
135+
boost::get<T>(x->Op()->GetAttr(attr_name)) == attr;
136+
});
137+
return this;
138+
}
139+
131140
private:
132141
PDNode(PDPattern* pattern, const std::string& name = "",
133142
Type type = Type::kVar)

paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
183183
SetConfig(&cfg);
184184
int num_ops;
185185
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
186-
GetFuseStatis(predictor.get(), &num_ops);
186+
187+
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
188+
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
189+
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
190+
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
191+
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
192+
EXPECT_EQ(num_ops, 32);
187193
}
188194

189195
// Compare result of NativeConfig and AnalysisConfig

0 commit comments

Comments
 (0)