Skip to content

Commit e561677

Browse files
committed
Fix
1 parent e342884 commit e561677

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

paddle/fluid/framework/ir/onednn/activation_onednn_fuse_pass.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,21 @@ inline std::unordered_map<std::string, std::string> GetAttributeMap(
5959
inline void SetActivationAttrs(paddle::framework::OpDesc* fused_op,
6060
paddle::framework::OpDesc* act_op,
6161
const std::string& act_type) {
62-
if (fused_op->HasAttr("use_mkldnn")) {
62+
bool use_mkldnn = false;
63+
if (fused_op->HasAttr("use_mkldnn") && !fused_op->HasAttr("use_onednn")) {
6364
PADDLE_ENFORCE(PADDLE_GET_CONST(bool, fused_op->GetAttr("use_mkldnn")),
6465
common::errors::PreconditionNotMet(
65-
"oneDNN activation fuses require use_mkldnn=True"));
66+
"oneDNN activation fuses require use_onednn=True"));
67+
}
68+
if (fused_op->HasAttr("use_mkldnn")) {
69+
use_mkldnn = PADDLE_GET_CONST(bool, fused_op->GetAttr("use_mkldnn"));
70+
}
71+
if (!use_mkldnn && fused_op->HasAttr("use_onednn")) {
72+
PADDLE_ENFORCE(PADDLE_GET_CONST(bool, fused_op->GetAttr("use_onednn")),
73+
common::errors::PreconditionNotMet(
74+
"oneDNN activation fuses require use_onednn=True"));
6675
}
67-
fused_op->SetAttr("use_mkldnn", true);
76+
fused_op->SetAttr("use_onednn", true);
6877

6978
auto attr_map = GetAttributeMap(act_type);
7079
for (const auto& attr : attr_map) {

0 commit comments

Comments
 (0)