Skip to content

Commit d08822d

Browse files
committed
Fix
1 parent c128f16 commit d08822d

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3318,12 +3318,7 @@ PDNode *patterns::MKLDNNInPlace::operator()() {
33183318
auto next_output = pattern->NewNode(next_op_out_repr())->AsOutput();
33193319

33203320
// Check if op is ONE-DNN enabled
3321-
if (possible_inplace_op->HasAttr("use_mkldnn")) {
3322-
possible_inplace_op->assert_op_attr("use_mkldnn", true);
3323-
}
3324-
if (possible_inplace_op->HasAttr("use_onednn")) {
3325-
possible_inplace_op->assert_op_attr("use_onednn", true);
3326-
}
3321+
possible_inplace_op->assert_op_attr_or("use_mkldnn", "use_onednn", true);
33273322

33283323
// linked structure
33293324
possible_inplace_op->LinksTo({output});

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,19 @@ struct PDNode {
168168
return this;
169169
}
170170

171+
PDNode* assert_op_attr_or(const std::string& attr_name1,
172+
const std::string& attr_name2,
173+
const T& attr) {
174+
asserts_.emplace_back([=](Node* x) {
175+
return x && x->IsOp() &&
176+
((x->Op()->HasAttr(attr_name1) &&
177+
PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name1)) == attr) ||
178+
(x->Op()->HasAttr(attr_name2) &&
179+
PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name2)) == attr));
180+
});
181+
return this;
182+
}
183+
171184
private:
172185
PDNode(PDPattern* pattern,
173186
const std::string& name = "",

0 commit comments

Comments
 (0)