@@ -2342,7 +2342,9 @@ PDNode *patterns::QuantConv::operator()(const std::string &conv_type) {
2342
2342
auto conv_op = pattern->NewNode (conv_op_repr ())->assert_is_op (conv_type);
2343
2343
conv_op->assert_more ([&](Node *node) {
2344
2344
return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
2345
- " bfloat16" ;
2345
+ " bfloat16" ||
2346
+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
2347
+ " bfloat16" ;
2346
2348
});
2347
2349
2348
2350
quant_op->LinksFrom ({quant_in}).LinksTo ({conv_in});
@@ -3172,7 +3174,8 @@ PDNode *patterns::QuantizePlacement::operator()(
3172
3174
auto *op =
3173
3175
pattern->NewNode (op_repr ())->assert_is_ops (quantize_enabled_op_types);
3174
3176
op->assert_more ([&](Node *node) {
3175
- return node->Op ()->GetAttrIfExists <bool >(" use_mkldnn" );
3177
+ return node->Op ()->GetAttrIfExists <bool >(" use_mkldnn" ) ||
3178
+ node->Op ()->GetAttrIfExists <bool >(" use_onednn" );
3176
3179
});
3177
3180
return op;
3178
3181
}
@@ -3218,6 +3221,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
3218
3221
auto *op = pattern->NewNode (op_repr ())->assert_is_ops (supported_op_types);
3219
3222
op->assert_more ([&](Node *node) {
3220
3223
return node->Op ()->GetAttrIfExists <bool >(" use_mkldnn" ) ||
3224
+ node->Op ()->GetAttrIfExists <bool >(" use_onednn" ) ||
3221
3225
node->Op ()->Type () == " reshape2" ;
3222
3226
});
3223
3227
op->LinksFrom ({op_in});
@@ -3227,25 +3231,35 @@ PDNode *patterns::Bfloat16Placement::operator()(
3227
3231
PDNode *patterns::OrphanedBfloat16::operator ()() {
3228
3232
auto *prev_op = pattern->NewNode (prev_op_repr ())->assert_is_op ();
3229
3233
prev_op->assert_more ([&](Node *node) {
3230
- bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" );
3231
- bool data_type_is_fp32 = node->Op ()->GetAttrIfExists <std::string>(
3232
- " mkldnn_data_type" ) == " float32" ;
3234
+ bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" ) &&
3235
+ !node->Op ()->HasAttr (" onednn_data_type" );
3236
+ bool data_type_is_fp32 =
3237
+ node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3238
+ " float32" ||
3239
+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3240
+ " float32" ;
3233
3241
return data_type_is_missing || data_type_is_fp32;
3234
3242
});
3235
3243
auto *prev_out = pattern->NewNode (prev_out_repr ())->AsOutput ();
3236
3244
3237
3245
auto *op = pattern->NewNode (op_repr ())->assert_is_op ();
3238
3246
op->assert_more ([&](Node *node) {
3239
3247
return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3240
- " bfloat16" ;
3248
+ " bfloat16" ||
3249
+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3250
+ " bfloat16" ;
3241
3251
});
3242
3252
auto *op_out = pattern->NewNode (op_out_repr ())->AsOutput ();
3243
3253
3244
3254
auto *next_op = pattern->NewNode (next_op_repr ())->assert_is_op ();
3245
3255
next_op->assert_more ([&](Node *node) {
3246
- bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" );
3247
- bool data_type_is_fp32 = node->Op ()->GetAttrIfExists <std::string>(
3248
- " mkldnn_data_type" ) == " float32" ;
3256
+ bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" ) &&
3257
+ !node->Op ()->HasAttr (" onednn_data_type" );
3258
+ bool data_type_is_fp32 =
3259
+ node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3260
+ " float32" ||
3261
+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3262
+ " float32" ;
3249
3263
return data_type_is_missing || data_type_is_fp32;
3250
3264
});
3251
3265
@@ -3258,14 +3272,17 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
3258
3272
PDNode *patterns::UnsupportedBfloat16::operator ()() {
3259
3273
auto *prev_op = pattern->NewNode (prev_op_repr ())->assert_is_op ();
3260
3274
prev_op->assert_more ([&](Node *node) {
3261
- return node->Op ()->HasAttr (" mkldnn_data_type" ) == false ;
3275
+ return node->Op ()->HasAttr (" mkldnn_data_type" ) == false &&
3276
+ node->Op ()->HasAttr (" onednn_data_type" ) == false ;
3262
3277
});
3263
3278
auto *prev_out = pattern->NewNode (prev_out_repr ())->AsOutput ();
3264
3279
3265
3280
auto *op = pattern->NewNode (op_repr ())->assert_is_op ();
3266
3281
op->assert_more ([&](Node *node) {
3267
3282
return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3268
- " bfloat16" ;
3283
+ " bfloat16" ||
3284
+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3285
+ " bfloat16" ;
3269
3286
});
3270
3287
prev_op->LinksTo ({prev_out});
3271
3288
op->LinksFrom ({prev_out});
@@ -3276,7 +3293,9 @@ PDNode *patterns::Bloat16Ops::operator()() {
3276
3293
auto op = pattern->NewNode (op_repr ())->assert_is_op ();
3277
3294
op->assert_more ([&](Node *node) {
3278
3295
return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3279
- " bfloat16" ;
3296
+ " bfloat16" ||
3297
+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3298
+ " bfloat16" ;
3280
3299
});
3281
3300
return op;
3282
3301
}
@@ -3298,8 +3317,8 @@ PDNode *patterns::ONEDNNInPlace::operator()() {
3298
3317
auto next_op = pattern->NewNode (next_op_repr ())->assert_is_op ();
3299
3318
auto next_output = pattern->NewNode (next_op_out_repr ())->AsOutput ();
3300
3319
3301
- // Check if op is MKL -DNN enabled
3302
- possible_inplace_op->assert_op_attr (" use_mkldnn" , true );
3320
+ // Check if op is ONE -DNN enabled
3321
+ possible_inplace_op->assert_op_attr_or (" use_mkldnn" , " use_onednn " , true );
3303
3322
3304
3323
// linked structure
3305
3324
possible_inplace_op->LinksTo ({output});
0 commit comments