@@ -210,6 +210,22 @@ QuantDequantFusePass::QuantDequantFusePass() {
210
210
.AddAttr (" y_num_col_dims" )
211
211
.IsNumEQ (1 )
212
212
.End ();
213
+ AddOpCompat (OpCompat (" matmul_v2" ))
214
+ .AddInput (" X" )
215
+ .IsTensor ()
216
+ .End ()
217
+ .AddInput (" Y" )
218
+ .IsTensor ()
219
+ .End ()
220
+ .AddOutput (" Out" )
221
+ .IsTensor ()
222
+ .End ()
223
+ .AddAttr (" trans_x" )
224
+ .IsBoolEQ (false )
225
+ .End ()
226
+ .AddAttr (" trans_y" )
227
+ .IsBoolEQ (false )
228
+ .End ();
213
229
AddOpCompat (OpCompat (" matmul" ))
214
230
.AddInput (" X" )
215
231
.IsTensor ()
@@ -355,7 +371,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
355
371
quantized_op_type == " fc" ||
356
372
quantized_op_type == " conv2d_transpose" ) {
357
373
op_desc->SetAttr (" Input_scale" , scale_value);
358
- } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ) {
374
+ } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ||
375
+ quantized_op_type == " matmul_v2" ) {
359
376
op_desc->SetAttr (" X_scale" , scale_value);
360
377
} else {
361
378
PADDLE_THROW (platform::errors::Unimplemented (
@@ -387,7 +404,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
387
404
quantized_op_type == " conv2d_transpose" ) {
388
405
weight_name = " Filter" ;
389
406
input_name = " Input" ;
390
- } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ) {
407
+ } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ||
408
+ quantized_op_type == " matmul_v2" ) {
391
409
weight_name = " Y" ;
392
410
input_name = " X" ;
393
411
} else if (quantized_op_type == " fc" ) {
@@ -396,7 +414,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
396
414
} else {
397
415
PADDLE_THROW (platform::errors::Unimplemented (
398
416
" QuantDequantFuse: We only support conv2d, conv2d_fusion, "
399
- " conv2d_transpose, fc, mul, matmul for "
417
+ " conv2d_transpose, fc, mul, matmul, matmul_v2 for "
400
418
" now." ));
401
419
}
402
420
const std::string pattern_name = " dequant_fuse" ;
@@ -437,7 +455,11 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
437
455
BOOST_GET_CONST (int , quantized_op_node->Op ()->GetAttr (" bit_length" ));
438
456
int range = ((1 << (bit_length - 1 )) - 1 );
439
457
std::vector<float > weight_scale;
440
-
458
+ int quant_axis = 0 ;
459
+ if (dequant_op_node->Op ()->HasAttr (" quant_axis" )) {
460
+ quant_axis =
461
+ BOOST_GET_CONST (int , dequant_op_node->Op ()->GetAttr (" quant_axis" ));
462
+ }
441
463
// Get weight scale
442
464
if (dequant_type == " fake_channel_wise_dequantize_max_abs" ) {
443
465
Node* dequant_channel_scale_node =
@@ -475,25 +497,37 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
475
497
// If quantized op is conv2d, weight scale size = weight dims[0]
476
498
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
477
499
if (quantized_op_type == " mul" || quantized_op_type == " matmul" ||
478
- quantized_op_type == " fc" ) {
500
+ quantized_op_type == " matmul_v2" || quantized_op_type == " fc" ) {
479
501
if (dequant_type == " fake_dequantize_max_abs" ) {
480
- PADDLE_ENFORCE_EQ (
481
- weight_scale. size (), 1 ,
482
- platform::errors::InvalidArgument (
483
- " mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
484
- " requires weight scale size = 1, but got %d." ,
485
- weight_scale.size ()));
502
+ PADDLE_ENFORCE_EQ (weight_scale. size (), 1 ,
503
+ platform::errors::InvalidArgument (
504
+ " mul/matmul/matmul_v2 op weight dequantized by "
505
+ " [fake_dequantize_max_abs] "
506
+ " requires weight scale size = 1, but got %d." ,
507
+ weight_scale.size ()));
486
508
for (int j = 0 ; j < weight_tensor->numel (); j++) {
487
509
quantized_weight_data[j] *= weight_scale[0 ];
488
510
}
489
511
}
490
512
if (dequant_type == " fake_channel_wise_dequantize_max_abs" ) {
513
+ if (quant_axis == 0 ) {
514
+ } else {
515
+ PADDLE_ENFORCE_EQ (
516
+ quant_axis == 1 , true ,
517
+ platform::errors::InvalidArgument (
518
+ " 'quant_axis' of mul/matmul/fc/matmul_v2 op weight "
519
+ " dequantized by "
520
+ " [fake_channel_wise_dequantize_max_abs] should be 1, but "
521
+ " the received is %d" ,
522
+ quant_axis));
523
+ }
491
524
PADDLE_ENFORCE_EQ (
492
525
weight_scale.size (), static_cast <size_t >(w_dims[1 ]),
493
526
platform::errors::InvalidArgument (
494
- " mul/matmul op weight dequantized by "
527
+ " mul/matmul/matmul_v2 op weight dequantized by "
495
528
" [fake_channel_wise_dequantize_max_abs] requires weight scale "
496
- " size = 2nd dim of mul/matmul's weight, which is %d, but got "
529
+ " size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
530
+ " but got "
497
531
" %d." ,
498
532
static_cast <size_t >(w_dims[1 ]), weight_scale.size ()));
499
533
for (int j = 0 ; j < weight_tensor->numel (); j++) {
@@ -511,6 +545,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
511
545
" model, please set the 'weight_quantize_type' params as "
512
546
" 'channel_wise_abs_max' and generate the quantized model again." ,
513
547
dequant_type));
548
+ if (quant_axis == 0 ) {
549
+ } else {
550
+ PADDLE_ENFORCE_EQ (
551
+ quant_axis == 0 , true ,
552
+ platform::errors::InvalidArgument (
553
+ " 'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
554
+ " by [fake_channel_wise_dequantize_max_abs] should be 0, but "
555
+ " the received is %d" ,
556
+ quant_axis));
557
+ }
514
558
PADDLE_ENFORCE_EQ (
515
559
weight_scale.size (), static_cast <size_t >(w_dims[0 ]),
516
560
platform::errors::InvalidArgument (
@@ -528,6 +572,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
528
572
" conv2d_transpose must be dequantized by "
529
573
" [fake_channel_wise_dequantize_max_abs], but got %s" ,
530
574
dequant_type));
575
+ if (quant_axis == 0 ) {
576
+ } else {
577
+ PADDLE_ENFORCE_EQ (
578
+ quant_axis == 1 , true ,
579
+ platform::errors::InvalidArgument (
580
+ " 'quant_axis' of conv2d_transpose op weight dequantized by "
581
+ " [fake_channel_wise_dequantize_max_abs] should be 1, but "
582
+ " the received is %d" ,
583
+ quant_axis));
584
+ }
531
585
PADDLE_ENFORCE_EQ (
532
586
weight_scale.size (), static_cast <size_t >(w_dims[1 ]),
533
587
platform::errors::InvalidArgument (
@@ -560,7 +614,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
560
614
} else if (quantized_op_type == " fc" ) {
561
615
new_op_desc.SetInput (" Input" , {new_input});
562
616
new_op_desc.SetOutput (" Out" , {new_output});
563
- } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ) {
617
+ } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ||
618
+ quantized_op_type == " matmul_v2" ) {
564
619
new_op_desc.SetInput (" X" , {new_input});
565
620
new_op_desc.SetOutput (" Out" , {new_output});
566
621
}
@@ -587,7 +642,9 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
587
642
std::unordered_set<std::string> quant_types = {
588
643
" fake_quantize_range_abs_max" , " fake_quantize_moving_average_abs_max" };
589
644
std::unordered_set<std::string> quantized_op_types = {
590
- " conv2d" , " mul" , " matmul" , " depthwise_conv2d" , " fc" , " conv2d_transpose" };
645
+ " conv2d" , " mul" , " matmul" , " depthwise_conv2d" ,
646
+ " conv2d_transpose" , " fc" , " matmul_v2" ,
647
+ };
591
648
auto * scope = param_scope ();
592
649
593
650
for (auto & quant_type : quant_types) {
0 commit comments