Skip to content

Commit ff70a26

Browse files
authored
[cherry-pick] Update quantization round and clip calculation methods (#43829)
* update quantization clip and round
* fix quantization clip and round Attribute
* fix typo
1 parent 9e776f6 commit ff70a26

20 files changed

+2406
-1538
lines changed

paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
4545
.End()
4646
.AddAttr("bit_length")
4747
.IsIntIn({8, 16})
48+
.End()
49+
.AddAttr("round_type")
50+
.IsOptional()
51+
.IsIntIn({0, 1})
4852
.End();
4953
AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
5054
.AddInput("X")
@@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
6165
.End()
6266
.AddAttr("quant_axis")
6367
.IsIntIn({0, 1})
68+
.End()
69+
.AddAttr("round_type")
70+
.IsOptional()
71+
.IsIntIn({0, 1})
6472
.End();
6573
}
6674
// Delete quant_dequant_op, then quantize and dequantize weight
@@ -96,15 +104,18 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
96104
auto var_map = any_op2_desc->Inputs();
97105
std::string arg_name = "";
98106
for (auto& name_m : var_map) {
99-
if (std::find(name_m.second.begin(), name_m.second.end(),
107+
if (std::find(name_m.second.begin(),
108+
name_m.second.end(),
100109
quant_dequant_op_out_name) != name_m.second.end()) {
101110
arg_name = name_m.first;
102111
break;
103112
}
104113
}
105-
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
106-
"can not find the input %s.",
107-
quant_dequant_op_out_name));
114+
PADDLE_ENFORCE_GT(
115+
arg_name.size(),
116+
0,
117+
platform::errors::InvalidArgument("can not find the input %s.",
118+
quant_dequant_op_out_name));
108119
// any_op2_desc->SetAttr("enable_int8", true);
109120
any_op2_desc->SetAttr("bit_length", bit_length);
110121

@@ -123,7 +134,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
123134
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
124135
int quant_axis =
125136
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
126-
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
137+
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
138+
true,
127139
platform::errors::InvalidArgument(
128140
"'quant_axis' should be 0 or 1, but "
129141
"the received is %d",
@@ -176,7 +188,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
176188
}
177189
}
178190
for (int i = 0; i < channel; i++) {
179-
PADDLE_ENFORCE_NE(weight_scale[i], 0,
191+
PADDLE_ENFORCE_NE(weight_scale[i],
192+
0,
180193
platform::errors::InvalidArgument(
181194
"Weight scale should be nonzero, but get zero."));
182195
weight_scale[i] = weight_scale[i] / range;
@@ -188,7 +201,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
188201
abs_max_weight =
189202
std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
190203
}
191-
PADDLE_ENFORCE_NE(abs_max_weight, 0,
204+
PADDLE_ENFORCE_NE(abs_max_weight,
205+
0,
192206
platform::errors::InvalidArgument(
193207
"Weight scale should be nonzero, but get zero"));
194208
weight_scale.push_back(abs_max_weight / range);

paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
5454
.End()
5555
.AddAttr("quant_axis")
5656
.IsType<int>()
57+
.End()
58+
.AddAttr("round_type")
59+
.IsOptional()
60+
.IsType<int>()
5761
.End();
5862
AddOpCompat(OpCompat("dequantize_linear"))
5963
.AddInput("X")
@@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
7478
.End()
7579
.AddAttr("quant_axis")
7680
.IsType<int>()
81+
.End()
82+
.AddAttr("round_type")
83+
.IsOptional()
84+
.IsType<int>()
7785
.End();
7886
}
7987
// Delete quantize_linear_op dequantize_linear_op, then add input_scales
@@ -112,7 +120,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
112120
const LoDTensor& input_scale_tensor =
113121
scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
114122
PADDLE_ENFORCE_EQ(
115-
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
123+
paddle::platform::is_cpu_place(input_scale_tensor.place()),
124+
true,
116125
platform::errors::InvalidArgument(
117126
"Input scale tensor's place should be CPU."));
118127
const float* input_scale_data = input_scale_tensor.data<float>();

paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
5252
.End()
5353
.AddAttr("quant_axis")
5454
.IsType<int>()
55+
.End()
56+
.AddAttr("round_type")
57+
.IsOptional()
58+
.IsType<int>()
5559
.End();
5660
AddOpCompat(OpCompat("dequantize_linear"))
5761
.AddInput("X")
@@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
7276
.End()
7377
.AddAttr("quant_axis")
7478
.IsType<int>()
79+
.End()
80+
.AddAttr("round_type")
81+
.IsOptional()
82+
.IsType<int>()
7583
.End();
7684
AddOpCompat(OpCompat("conv2d"))
7785
.AddInput("Input")
@@ -322,7 +330,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
322330
int quant_axis = BOOST_GET_CONST(
323331
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
324332
if (quant_axis == -1) { // per_layer quant_dequant: all OP
325-
PADDLE_ENFORCE_EQ(weight_scale_nums, 1,
333+
PADDLE_ENFORCE_EQ(weight_scale_nums,
334+
1,
326335
platform::errors::InvalidArgument(
327336
"When quant_axis == -1 means use per_layer "
328337
"quant_dequant, weight_scale'number should be 1."));
@@ -335,11 +344,13 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
335344
} else if (quant_axis == 0) { // per_channel quant_dequant: conv2d,
336345
// depthwise_conv2d, conv2d_fusion
337346
PADDLE_ENFORCE_EQ(
338-
weight_scale_nums, w_dims[quant_axis],
347+
weight_scale_nums,
348+
w_dims[quant_axis],
339349
platform::errors::InvalidArgument(
340350
"When quant_axis == 0 means use per_channel quant_dequant, "
341351
"weight_scale'numbers should be equal channels."));
342-
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
352+
PADDLE_ENFORCE_EQ(w_dims.size(),
353+
4,
343354
platform::errors::InvalidArgument(
344355
"When quant_axis == 0 means use per_channel "
345356
"quant_dequant, (conv2d, depthwise_conv2d, "
@@ -352,15 +363,17 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
352363
}
353364
} else if (quant_axis == 1) {
354365
PADDLE_ENFORCE_EQ(
355-
weight_scale_nums, w_dims[quant_axis],
366+
weight_scale_nums,
367+
w_dims[quant_axis],
356368
platform::errors::InvalidArgument(
357369
"When quant_axis == 1 means use per_channel quant_dequant, "
358370
"weight_scale'numbers should be equal channels."));
359371

360372
if (w_dims.size() == 4) { // conv2d_transpose
361373
std::string quantized_op_type = any_op2->Op()->Type();
362374
PADDLE_ENFORCE_EQ(
363-
quantized_op_type, "conv2d_transpose",
375+
quantized_op_type,
376+
"conv2d_transpose",
364377
platform::errors::InvalidArgument(
365378
"When quant_axis == 1 means use per_channel quant_dequant, "
366379
"only conv2d_transpose weight dims equal 4."));
@@ -388,7 +401,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
388401
weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
389402
float* new_quantized_weight_data =
390403
weight_tensor->mutable_data<float>(platform::CPUPlace());
391-
memcpy(new_quantized_weight_data, weight_data_tmp.data(),
404+
memcpy(new_quantized_weight_data,
405+
weight_data_tmp.data(),
392406
weight_tensor->numel() * sizeof(float));
393407

394408
nodes2rm.insert(weight_dequantize_linear_op_scale);

0 commit comments

Comments
 (0)