Commit ef79443

Merge branch 'develop' into split_api
2 parents 7ea4fa5 + ab55c08

36 files changed: +834 −67 lines

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
Lines changed: 7 additions & 6 deletions

@@ -195,7 +195,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
                         Node* cell,
                         Node* xx,
                         Node* fc_bias,
-                        const bool use_mkldnn) {
+                        const bool use_onednn) {
   OpDesc op_desc;
   op_desc.SetType("fusion_lstm");
 #define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
@@ -235,7 +235,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
   op_desc.SetOutput("XX", {xx->Name()});
   op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
   op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
-  op_desc.SetAttr("use_mkldnn", use_mkldnn);
+  op_desc.SetAttr("use_onednn", use_onednn);
   // TODO(TJ): get from attr
   op_desc.SetAttr("use_seq", true);

@@ -300,8 +300,9 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
   GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
-  const bool use_mkldnn =
-      (mul->Op()->GetAttrIfExists<bool>("use_mkldnn") &&
+  const bool use_onednn =
+      ((mul->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+        mul->Op()->GetAttrIfExists<bool>("use_onednn")) &&
       lstm->Op()->GetAttrIfExists<std::string>("gate_activation") ==
           "sigmoid" &&
       lstm->Op()->GetAttrIfExists<std::string>("cell_activation") ==
@@ -323,7 +324,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
                    Cell,
                    fc_out,
                    fc_bias,
-                   use_mkldnn);
+                   use_onednn);
   // Remove unneeded nodes.
   std::unordered_set<const Node*> marked_nodes(
       {mul, lstm, elementwise_add, mul_out, BatchGate, BatchCellPreAct});
@@ -339,7 +340,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
                    Cell,
                    fc_out,
                    nullptr,
-                   use_mkldnn);
+                   use_onednn);
   // Remove unneeded nodes.
   std::unordered_set<const Node*> marked_nodes(
       {mul, lstm, BatchGate, BatchCellPreAct});
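
The pattern repeated across this commit: attribute reads accept both the legacy use_mkldnn spelling and the new use_onednn one, while writes emit only use_onednn. A minimal self-contained sketch of that read-side fallback, with a plain std::map standing in for Paddle's attribute storage (the BoolAttrs/UsesOneDNN names below are illustrative, not Paddle API):

    #include <map>
    #include <string>

    // Illustrative stand-in for an op's boolean attribute map.
    using BoolAttrs = std::map<std::string, bool>;

    // True when the op requests the oneDNN backend under either the new
    // attribute name or the legacy mkldnn one; a missing attribute reads
    // as false, matching the GetAttrIfExists<bool> semantics above.
    bool UsesOneDNN(const BoolAttrs& attrs) {
      auto is_set = [&](const std::string& name) {
        auto it = attrs.find(name);
        return it != attrs.end() && it->second;
      };
      return is_set("use_onednn") || is_set("use_mkldnn");
    }

Keeping the legacy name on the read side means graphs serialized before the rename still route to oneDNN kernels, while newly fused ops carry only use_onednn.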

paddle/fluid/framework/ir/fuse_pass_base.cc
Lines changed: 10 additions & 4 deletions

@@ -58,10 +58,16 @@ void FusePassBase::AddStatis(int count_of_fused) const {
 FuseOptions FusePassBase::FindFuseOption(const Node& node1,
                                          const Node& node2) const {
 #ifdef PADDLE_WITH_DNNL
-  bool node1_onednn = node1.Op()->HasAttr("use_mkldnn") &&
-                      PADDLE_GET_CONST(bool, node1.Op()->GetAttr("use_mkldnn"));
-  bool node2_onednn = node2.Op()->HasAttr("use_mkldnn") &&
-                      PADDLE_GET_CONST(bool, node2.Op()->GetAttr("use_mkldnn"));
+  bool node1_onednn =
+      (node1.Op()->HasAttr("use_mkldnn") &&
+       PADDLE_GET_CONST(bool, node1.Op()->GetAttr("use_mkldnn"))) ||
+      (node1.Op()->HasAttr("use_onednn") &&
+       PADDLE_GET_CONST(bool, node1.Op()->GetAttr("use_onednn")));
+  bool node2_onednn =
+      (node2.Op()->HasAttr("use_mkldnn") &&
+       PADDLE_GET_CONST(bool, node2.Op()->GetAttr("use_mkldnn"))) ||
+      (node2.Op()->HasAttr("use_onednn") &&
+       PADDLE_GET_CONST(bool, node2.Op()->GetAttr("use_onednn")));
   if (node1_onednn && node2_onednn)
     return FUSE_ONEDNN;
   else if (!node1_onednn && !node2_onednn)
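
FindFuseOption itself is a three-way decision on the two booleans computed above: fuse via oneDNN when both ops opt in, fuse natively when neither does, and refuse to fuse a mixed pair. A stand-alone sketch of that decision (only FUSE_ONEDNN is visible in the diff; the other two enum names are assumptions for illustration):

    // Assumed to mirror Paddle's FuseOptions; only FUSE_ONEDNN appears
    // in the diff above, the other two names are assumptions.
    enum FuseOptions { DO_NOT_FUSE, FUSE_NATIVE, FUSE_ONEDNN };

    // Two ops can only be fused when they agree on a backend; fusing a
    // oneDNN op with a native op would leave no kernel to dispatch to.
    FuseOptions FindFuseOption(bool node1_onednn, bool node2_onednn) {
      if (node1_onednn && node2_onednn) return FUSE_ONEDNN;
      if (!node1_onednn && !node2_onednn) return FUSE_NATIVE;
      return DO_NOT_FUSE;
    }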

paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc
Lines changed: 3 additions & 0 deletions

@@ -444,6 +444,9 @@ void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
   if (matmul_v2_op->Op()->HasAttr("use_mkldnn")) {
     desc.SetAttr("use_mkldnn", matmul_v2_op->Op()->GetAttr("use_mkldnn"));
   }
+  if (matmul_v2_op->Op()->HasAttr("use_onednn")) {
+    desc.SetAttr("use_onednn", matmul_v2_op->Op()->GetAttr("use_onednn"));
+  }
   if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
     desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
     desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));

paddle/fluid/framework/ir/graph_pattern_detector.cc
Lines changed: 1 addition & 1 deletion

@@ -3289,7 +3289,7 @@ PDNode *patterns::UnsupportedBfloat16::operator()() {
   return op;
 }

-PDNode *patterns::Bloat16Ops::operator()() {
+PDNode *patterns::Bfloat16Ops::operator()() {
   auto op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==

paddle/fluid/framework/ir/graph_pattern_detector.h
Lines changed: 2 additions & 2 deletions

@@ -1774,8 +1774,8 @@ struct UnsupportedBfloat16 : public PatternBase {
   PATTERN_DECL_NODE(op);
 };

-struct Bloat16Ops : public PatternBase {
-  Bloat16Ops(PDPattern* pattern, const std::string& name_scope)
+struct Bfloat16Ops : public PatternBase {
+  Bfloat16Ops(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "many_bfloat16_ops") {}

   PDNode* operator()();

paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass.cc
Lines changed: 3 additions & 3 deletions

@@ -249,11 +249,11 @@ void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
   int dequantize_counter = 0;

   GraphPatternDetector gpd;
-  patterns::Bloat16Ops Bloat16Ops{gpd.mutable_pattern(), "Bloat16Ops"};
-  Bloat16Ops();
+  patterns::Bfloat16Ops Bfloat16Ops{gpd.mutable_pattern(), "Bfloat16Ops"};
+  Bfloat16Ops();
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* graph) {
-    GET_IR_NODE_FROM_SUBGRAPH(op, op, Bloat16Ops);
+    GET_IR_NODE_FROM_SUBGRAPH(op, op, Bfloat16Ops);

     Quantizer quantizer(graph, op);
     quantizer.AddQuantOps();

paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass_tester.cc
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ void SetOp(ProgramDesc* prog,
            const std::string& onednn_data_type = "float32") {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
-  op->SetAttr("use_mkldnn", use_onednn);
+  op->SetAttr("use_onednn", use_onednn);
   op->SetAttr("name", name);

   if (type == "conv2d") {

paddle/fluid/framework/ir/onednn/interpolate_onednn_pass.cc
Lines changed: 3 additions & 2 deletions

@@ -31,7 +31,8 @@ void InterpolateOneDNNPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(graph,
                           common::errors::InvalidArgument(
                               "Pointer to graph argument should not be NULL."));
-  if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
+  if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn")) &&
+      !(graph->Has("use_onednn") && graph->Get<bool>("use_onednn"))) {
     VLOG(3) << "Do not handle interpolate_onednn_pass";
     return;
   }
@@ -53,7 +54,7 @@ void InterpolateOneDNNPass::ApplyImpl(ir::Graph* graph) const {
                      interpolate_op_types.end(),
                      node->Name()) != interpolate_op_types.end()) {
       auto* op_desc = node->Op();
-      op_desc->SetAttr("use_mkldnn", true);
+      op_desc->SetAttr("use_onednn", true);
       ++found_count;
     }
   }

paddle/fluid/framework/ir/onednn/multi_gru_fuse_pass.cc
Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ MultiGRUFusePass::MultiGRUFusePass() {
       .AddAttr("origin_mode")
       .IsType<bool>()
       .End()
-      .AddAttr("use_mkldnn")
+      .AddAttr("use_onednn")
       .IsType<bool>()
       .End()
       .AddAttr("mkldnn_data_type")

paddle/fluid/framework/op_desc.h
Lines changed: 1 addition & 1 deletion

@@ -235,7 +235,7 @@ class TEST_API OpDesc {
   // attribute name => all original attrs
   AttributeMap attrs_;
   // runtime_attrs_ contains the attributes which used for dispatching kernel
-  // (use_mkldnn, use_cudnn, ...) or passing additional configuration for
+  // (use_onednn, use_cudnn, ...) or passing additional configuration for
   // special heterogeneous kernel (workspace_size_MB, ...).
   // The attributes in runtime_attrs_ are set by framework (such as PASS),
   // and not in the python api.
