
Commit 7647d40

Update quant_conv2d_dequant_fuse_pass.cc (#36821)
1 parent: f20c5c9
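
This change extends QuantDequantFusePass to treat matmul_v2 as a quantizable op: it registers an op-compat entry for matmul_v2, records its input scale, dequantizes its weight, and rewires its inputs alongside mul and matmul. It also reads and validates the quant_axis attribute of channel-wise fake-dequantize ops.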


paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc

Lines changed: 72 additions & 15 deletions
@@ -210,6 +210,22 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .AddAttr("y_num_col_dims")
       .IsNumEQ(1)
       .End();
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("trans_y")
+      .IsBoolEQ(false)
+      .End();
   AddOpCompat(OpCompat("matmul"))
       .AddInput("X")
       .IsTensor()
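
The new compat entry admits matmul_v2 only when both transpose flags are false. A minimal sketch of the equivalent condition, using a hypothetical helper (not the OpCompat API itself):

// Hypothetical helper mirroring the entry above: matmul_v2 qualifies for
// fusion only when trans_x == false and trans_y == false.
bool IsPlainMatmulV2(const paddle::framework::OpDesc& op) {
  return op.Type() == "matmul_v2" &&
         !BOOST_GET_CONST(bool, op.GetAttr("trans_x")) &&
         !BOOST_GET_CONST(bool, op.GetAttr("trans_y"));
}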
@@ -355,7 +371,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
                quantized_op_type == "fc" ||
                quantized_op_type == "conv2d_transpose") {
       op_desc->SetAttr("Input_scale", scale_value);
-    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+               quantized_op_type == "matmul_v2") {
       op_desc->SetAttr("X_scale", scale_value);
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
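
A note on the branch above: conv-like ops record the activation scale under the attribute Input_scale, while mul/matmul-like ops, now including matmul_v2, record it under X_scale, matching the input slot ("Input" vs. "X") that each op reads in the following hunk.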
@@ -387,7 +404,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
                quantized_op_type == "conv2d_transpose") {
       weight_name = "Filter";
       input_name = "Input";
-    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+               quantized_op_type == "matmul_v2") {
       weight_name = "Y";
       input_name = "X";
     } else if (quantized_op_type == "fc") {
@@ -396,7 +414,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "QuantDequantFuse: We only support conv2d, conv2d_fusion, "
-          "conv2d_transpose, fc, mul, matmul for "
+          "conv2d_transpose, fc, mul, matmul, matmul_v2 for "
           "now."));
     }
     const std::string pattern_name = "dequant_fuse";
@@ -437,7 +455,11 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
       BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length"));
   int range = ((1 << (bit_length - 1)) - 1);
   std::vector<float> weight_scale;
-
+  int quant_axis = 0;
+  if (dequant_op_node->Op()->HasAttr("quant_axis")) {
+    quant_axis =
+        BOOST_GET_CONST(int, dequant_op_node->Op()->GetAttr("quant_axis"));
+  }
   // Get weight scale
   if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
     Node* dequant_channel_scale_node =
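
Reading quant_axis defensively matters for backward compatibility: quantized programs produced before the attribute existed simply omit it, and the default of 0 lets the axis checks below (the deliberately empty `if (quant_axis == 0) {} else { ... }` branches) pass without complaint for such models.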
@@ -475,25 +497,37 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
     // If quantized op is conv2d, weight scale size = weight dims[0]
     // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
     if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-        quantized_op_type == "fc") {
+        quantized_op_type == "matmul_v2" || quantized_op_type == "fc") {
       if (dequant_type == "fake_dequantize_max_abs") {
-        PADDLE_ENFORCE_EQ(
-            weight_scale.size(), 1,
-            platform::errors::InvalidArgument(
-                "mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
-                "requires weight scale size = 1, but got %d.",
-                weight_scale.size()));
+        PADDLE_ENFORCE_EQ(weight_scale.size(), 1,
+                          platform::errors::InvalidArgument(
+                              "mul/matmul/matmul_v2 op weight dequantized by "
+                              "[fake_dequantize_max_abs] "
+                              "requires weight scale size = 1, but got %d.",
+                              weight_scale.size()));
         for (int j = 0; j < weight_tensor->numel(); j++) {
           quantized_weight_data[j] *= weight_scale[0];
         }
       }
       if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
+        if (quant_axis == 0) {
+        } else {
+          PADDLE_ENFORCE_EQ(
+              quant_axis == 1, true,
+              platform::errors::InvalidArgument(
+                  "'quant_axis' of mul/matmul/fc/matmul_v2 op weight "
+                  "dequantized by "
+                  "[fake_channel_wise_dequantize_max_abs]should be 1, but "
+                  "the received is %d",
+                  quant_axis));
+        }
         PADDLE_ENFORCE_EQ(
             weight_scale.size(), static_cast<size_t>(w_dims[1]),
             platform::errors::InvalidArgument(
-                "mul/matmul op weight dequantized by "
+                "mul/matmul/matmul_v2 op weight dequantized by "
                 "[fake_channel_wise_dequantize_max_abs] requires weight scale "
-                "size = 2nd dim of mul/matmul's weight, which is %d, but got "
+                "size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
+                "but got "
                 "%d.",
                 static_cast<size_t>(w_dims[1]), weight_scale.size()));
         for (int j = 0; j < weight_tensor->numel(); j++) {
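
The check above pins the channel-wise scale count to w_dims[1] for mul/matmul/matmul_v2/fc. A sketch of the multiply the loop performs, under the assumption (as elsewhere in the pass) that weight_scale was read with the quantization range already divided out:

// Row-major [w_dims[0], w_dims[1]] weight: one scale per output column,
// so column index j % w_dims[1] selects the scale for element j.
for (int j = 0; j < weight_tensor->numel(); j++) {
  quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
}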
@@ -511,6 +545,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
               "model, please set the 'weight_quantize_type' params as "
               "'channel_wise_abs_max' and generate the quantized model again.",
               dequant_type));
+      if (quant_axis == 0) {
+      } else {
+        PADDLE_ENFORCE_EQ(
+            quant_axis == 0, true,
+            platform::errors::InvalidArgument(
+                "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
+                "by [fake_channel_wise_dequantize_max_abs]should be 0, but "
+                "the received is %d",
+                quant_axis));
+      }
       PADDLE_ENFORCE_EQ(
           weight_scale.size(), static_cast<size_t>(w_dims[0]),
           platform::errors::InvalidArgument(
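
For conv2d/depthwise_conv2d the scale count is checked against w_dims[0] instead, since each output channel owns one scale. A sketch of the corresponding axis-0 multiply, assuming the usual OIHW filter layout:

// inner_size consecutive elements belong to one output channel, so
// j / inner_size selects that channel's scale.
const int inner_size = weight_tensor->numel() / w_dims[0];
for (int j = 0; j < weight_tensor->numel(); j++) {
  quantized_weight_data[j] *= weight_scale[j / inner_size];
}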
@@ -528,6 +572,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
               "conv2d_transpose must be dequantized by "
               "[fake_channel_wise_dequantize_max_abs], but got %s",
               dequant_type));
+      if (quant_axis == 0) {
+      } else {
+        PADDLE_ENFORCE_EQ(
+            quant_axis == 1, true,
+            platform::errors::InvalidArgument(
+                "'quant_axis' of conv2d_transpose op weight dequantized by "
+                "[fake_channel_wise_dequantize_max_abs]should be 1, but "
+                "the received is %d",
+                quant_axis));
+      }
       PADDLE_ENFORCE_EQ(
           weight_scale.size(), static_cast<size_t>(w_dims[1]),
           platform::errors::InvalidArgument(
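
conv2d_transpose filters store output channels in dims[1] (IOHW layout), which is why its scales are checked against w_dims[1] and its quant_axis must be 1. A sketch of the axis-1 multiply under that layout assumption:

// Each h*w plane shares the scale of its output channel (j / hw) % w_dims[1].
const int hw = w_dims[2] * w_dims[3];
for (int j = 0; j < weight_tensor->numel(); j++) {
  quantized_weight_data[j] *= weight_scale[(j / hw) % w_dims[1]];
}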
@@ -560,7 +614,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
     } else if (quantized_op_type == "fc") {
       new_op_desc.SetInput("Input", {new_input});
       new_op_desc.SetOutput("Out", {new_output});
-    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+               quantized_op_type == "matmul_v2") {
       new_op_desc.SetInput("X", {new_input});
       new_op_desc.SetOutput("Out", {new_output});
     }
@@ -587,7 +642,9 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
   std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "matmul", "depthwise_conv2d", "fc", "conv2d_transpose"};
+      "conv2d", "mul",    "matmul", "depthwise_conv2d",
+      "conv2d_transpose", "fc",     "matmul_v2",
+  };
   auto* scope = param_scope();
 
   for (auto& quant_type : quant_types) {
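
For reference, a minimal sketch of exercising the pass the way the framework's pass tests typically do (registry name taken from this file; graph setup assumed):

// Retrieve the registered pass and apply it to an ir::Graph.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "quant_conv2d_dequant_fuse_pass");
graph.reset(pass->Apply(graph.release()));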
