@@ -59,12 +59,21 @@ inline std::unordered_map<std::string, std::string> GetAttributeMap(
59
59
inline void SetActivationAttrs (paddle::framework::OpDesc* fused_op,
60
60
paddle::framework::OpDesc* act_op,
61
61
const std::string& act_type) {
62
- if (fused_op->HasAttr (" use_mkldnn" )) {
62
+ bool use_mkldnn = false ;
63
+ if (fused_op->HasAttr (" use_mkldnn" ) && !fused_op->HasAttr (" use_onednn" )) {
63
64
PADDLE_ENFORCE (PADDLE_GET_CONST (bool , fused_op->GetAttr (" use_mkldnn" )),
64
65
common::errors::PreconditionNotMet (
65
- " oneDNN activation fuses require use_mkldnn=True" ));
66
+ " oneDNN activation fuses require use_onednn=True" ));
67
+ }
68
+ if (fused_op->HasAttr (" use_mkldnn" )) {
69
+ use_mkldnn = PADDLE_GET_CONST (bool , fused_op->GetAttr (" use_mkldnn" ));
70
+ }
71
+ if (!use_mkldnn && fused_op->HasAttr (" use_onednn" )) {
72
+ PADDLE_ENFORCE (PADDLE_GET_CONST (bool , fused_op->GetAttr (" use_onednn" )),
73
+ common::errors::PreconditionNotMet (
74
+ " oneDNN activation fuses require use_onednn=True" ));
66
75
}
67
- fused_op->SetAttr (" use_mkldnn " , true );
76
+ fused_op->SetAttr (" use_onednn " , true );
68
77
69
78
auto attr_map = GetAttributeMap (act_type);
70
79
for (const auto & attr : attr_map) {
0 commit comments