Skip to content

Commit 92a2dfb

Browse files
co63ocmaxiaolong001
authored andcommitted
graph_pattern_detector.cc add check use_onednn [fluid_ops] (PaddlePaddle#74517)
1 parent 314ca22 commit 92a2dfb

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2342,7 +2342,9 @@ PDNode *patterns::QuantConv::operator()(const std::string &conv_type) {
23422342
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type);
23432343
conv_op->assert_more([&](Node *node) {
23442344
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
2345-
"bfloat16";
2345+
"bfloat16" ||
2346+
node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
2347+
"bfloat16";
23462348
});
23472349

23482350
quant_op->LinksFrom({quant_in}).LinksTo({conv_in});
@@ -3172,7 +3174,8 @@ PDNode *patterns::QuantizePlacement::operator()(
31723174
auto *op =
31733175
pattern->NewNode(op_repr())->assert_is_ops(quantize_enabled_op_types);
31743176
op->assert_more([&](Node *node) {
3175-
return node->Op()->GetAttrIfExists<bool>("use_mkldnn");
3177+
return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
3178+
node->Op()->GetAttrIfExists<bool>("use_onednn");
31763179
});
31773180
return op;
31783181
}
@@ -3218,6 +3221,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
32183221
auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
32193222
op->assert_more([&](Node *node) {
32203223
return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
3224+
node->Op()->GetAttrIfExists<bool>("use_onednn") ||
32213225
node->Op()->Type() == "reshape2";
32223226
});
32233227
op->LinksFrom({op_in});
@@ -3227,25 +3231,35 @@ PDNode *patterns::Bfloat16Placement::operator()(
32273231
PDNode *patterns::OrphanedBfloat16::operator()() {
32283232
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
32293233
prev_op->assert_more([&](Node *node) {
3230-
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
3231-
bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
3232-
"mkldnn_data_type") == "float32";
3234+
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type") &&
3235+
!node->Op()->HasAttr("onednn_data_type");
3236+
bool data_type_is_fp32 =
3237+
node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
3238+
"float32" ||
3239+
node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
3240+
"float32";
32333241
return data_type_is_missing || data_type_is_fp32;
32343242
});
32353243
auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
32363244

32373245
auto *op = pattern->NewNode(op_repr())->assert_is_op();
32383246
op->assert_more([&](Node *node) {
32393247
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
3240-
"bfloat16";
3248+
"bfloat16" ||
3249+
node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
3250+
"bfloat16";
32413251
});
32423252
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
32433253

32443254
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
32453255
next_op->assert_more([&](Node *node) {
3246-
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
3247-
bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
3248-
"mkldnn_data_type") == "float32";
3256+
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type") &&
3257+
!node->Op()->HasAttr("onednn_data_type");
3258+
bool data_type_is_fp32 =
3259+
node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
3260+
"float32" ||
3261+
node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
3262+
"float32";
32493263
return data_type_is_missing || data_type_is_fp32;
32503264
});
32513265

@@ -3258,14 +3272,17 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
32583272
PDNode *patterns::UnsupportedBfloat16::operator()() {
32593273
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
32603274
prev_op->assert_more([&](Node *node) {
3261-
return node->Op()->HasAttr("mkldnn_data_type") == false;
3275+
return node->Op()->HasAttr("mkldnn_data_type") == false &&
3276+
node->Op()->HasAttr("onednn_data_type") == false;
32623277
});
32633278
auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
32643279

32653280
auto *op = pattern->NewNode(op_repr())->assert_is_op();
32663281
op->assert_more([&](Node *node) {
32673282
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
3268-
"bfloat16";
3283+
"bfloat16" ||
3284+
node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
3285+
"bfloat16";
32693286
});
32703287
prev_op->LinksTo({prev_out});
32713288
op->LinksFrom({prev_out});
@@ -3276,7 +3293,9 @@ PDNode *patterns::Bloat16Ops::operator()() {
32763293
auto op = pattern->NewNode(op_repr())->assert_is_op();
32773294
op->assert_more([&](Node *node) {
32783295
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
3279-
"bfloat16";
3296+
"bfloat16" ||
3297+
node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
3298+
"bfloat16";
32803299
});
32813300
return op;
32823301
}
@@ -3298,8 +3317,8 @@ PDNode *patterns::ONEDNNInPlace::operator()() {
32983317
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
32993318
auto next_output = pattern->NewNode(next_op_out_repr())->AsOutput();
33003319

3301-
// Check if op is MKL-DNN enabled
3302-
possible_inplace_op->assert_op_attr("use_mkldnn", true);
3320+
// Check if op is ONE-DNN enabled
3321+
possible_inplace_op->assert_op_attr_or("use_mkldnn", "use_onednn", true);
33033322

33043323
// linked structure
33053324
possible_inplace_op->LinksTo({output});

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,20 @@ struct PDNode {
168168
return this;
169169
}
170170

171+
template <typename T>
172+
PDNode* assert_op_attr_or(const std::string& attr_name1,
173+
const std::string& attr_name2,
174+
const T& attr) {
175+
asserts_.emplace_back([=](Node* x) {
176+
return x && x->IsOp() &&
177+
((x->Op()->HasAttr(attr_name1) &&
178+
PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name1)) == attr) ||
179+
(x->Op()->HasAttr(attr_name2) &&
180+
PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name2)) == attr));
181+
});
182+
return this;
183+
}
184+
171185
private:
172186
PDNode(PDPattern* pattern,
173187
const std::string& name = "",

0 commit comments

Comments
 (0)