diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 3471e9e1f6efe1..5237a69f64d136 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2342,7 +2342,9 @@ PDNode *patterns::QuantConv::operator()(const std::string &conv_type) {
   auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type);
   conv_op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
 
   quant_op->LinksFrom({quant_in}).LinksTo({conv_in});
@@ -3172,7 +3174,8 @@ PDNode *patterns::QuantizePlacement::operator()(
   auto *op =
       pattern->NewNode(op_repr())->assert_is_ops(quantize_enabled_op_types);
   op->assert_more([&](Node *node) {
-    return node->Op()->GetAttrIfExists<bool>("use_mkldnn");
+    return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+           node->Op()->GetAttrIfExists<bool>("use_onednn");
   });
   return op;
 }
@@ -3218,6 +3221,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
   auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+           node->Op()->GetAttrIfExists<bool>("use_onednn") ||
            node->Op()->Type() == "reshape2";
   });
   op->LinksFrom({op_in});
@@ -3227,9 +3231,13 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
 PDNode *patterns::OrphanedBfloat16::operator()() {
   auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
   prev_op->assert_more([&](Node *node) {
-    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
-    bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
-                                 "mkldnn_data_type") == "float32";
+    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type") &&
+                                !node->Op()->HasAttr("onednn_data_type");
+    bool data_type_is_fp32 =
+        node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+            "float32" ||
+        node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+            "float32";
     return data_type_is_missing || data_type_is_fp32;
   });
   auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
@@ -3237,15 +3245,21 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
   auto *op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
   auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
 
   auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
   next_op->assert_more([&](Node *node) {
-    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
-    bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
-                                 "mkldnn_data_type") == "float32";
+    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type") &&
+                                !node->Op()->HasAttr("onednn_data_type");
+    bool data_type_is_fp32 =
+        node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+            "float32" ||
+        node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+            "float32";
     return data_type_is_missing || data_type_is_fp32;
   });
 
@@ -3258,14 +3272,17 @@ PDNode *patterns::UnsupportedBfloat16::operator()() {
 PDNode *patterns::UnsupportedBfloat16::operator()() {
   auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
   prev_op->assert_more([&](Node *node) {
-    return node->Op()->HasAttr("mkldnn_data_type") == false;
+    return node->Op()->HasAttr("mkldnn_data_type") == false &&
+           node->Op()->HasAttr("onednn_data_type") == false;
   });
   auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
 
   auto *op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
   prev_op->LinksTo({prev_out});
   op->LinksFrom({prev_out});
@@ -3276,7 +3293,9 @@ PDNode *patterns::Bloat16Ops::operator()() {
   auto op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
   return op;
 }
@@ -3298,8 +3317,8 @@ PDNode *patterns::MKLDNNInPlace::operator()() {
   auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
   auto next_output = pattern->NewNode(next_op_out_repr())->AsOutput();
 
-  // Check if op is MKL-DNN enabled
-  possible_inplace_op->assert_op_attr("use_mkldnn", true);
+  // Check if op is ONE-DNN enabled
+  possible_inplace_op->assert_op_attr_or("use_mkldnn", "use_onednn", true);
 
   // linked structure
   possible_inplace_op->LinksTo({output});
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 6110ac20214401..eb9d3f8be4c3e5 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -168,6 +168,20 @@ struct PDNode {
     return this;
   }
 
+  template <typename T>
+  PDNode* assert_op_attr_or(const std::string& attr_name1,
+                            const std::string& attr_name2,
+                            const T& attr) {
+    asserts_.emplace_back([=](Node* x) {
+      return x && x->IsOp() &&
+             ((x->Op()->HasAttr(attr_name1) &&
+               PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name1)) == attr) ||
+              (x->Op()->HasAttr(attr_name2) &&
+               PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name2)) == attr));
+    });
+    return this;
+  }
+
  private:
   PDNode(PDPattern* pattern,
         const std::string& name = "",
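Reviewer note on the semantics of the new `assert_op_attr_or` helper: the predicate passes when either of the two attribute names is present and carries the expected value, which is what lets the same pattern match both models that still set the legacy `use_mkldnn` flag and models migrated to `use_onednn`. The same either/or shape is applied to the string-valued `mkldnn_data_type`/`onednn_data_type` checks in the .cc hunks above. Below is a minimal, self-contained sketch of that OR-check; `OpStub` and `attr_or` are illustrative stand-ins for this note only, not Paddle APIs.

// Stand-alone model of the assert_op_attr_or semantics added in this diff.
// OpStub and attr_or are hypothetical stand-ins, not Paddle code.
#include <cassert>
#include <map>
#include <string>

struct OpStub {
  std::map<std::string, bool> attrs;  // simplified: bool attributes only
  bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
  bool GetAttr(const std::string& name) const { return attrs.at(name); }
};

// Mirrors the predicate body: true if either attribute exists and equals
// the expected value; an op with neither attribute never matches.
bool attr_or(const OpStub& op, const std::string& a1, const std::string& a2,
             bool expected) {
  return (op.HasAttr(a1) && op.GetAttr(a1) == expected) ||
         (op.HasAttr(a2) && op.GetAttr(a2) == expected);
}

int main() {
  OpStub legacy{{{"use_mkldnn", true}}};   // old attribute name only
  OpStub renamed{{{"use_onednn", true}}};  // new attribute name only
  OpStub neither{};                        // no flag set at all

  assert(attr_or(legacy, "use_mkldnn", "use_onednn", true));
  assert(attr_or(renamed, "use_mkldnn", "use_onednn", true));
  assert(!attr_or(neither, "use_mkldnn", "use_onednn", true));
  return 0;
}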