
Commit 011a6a5

added support for fake_quantize_dequantize_abs_max op in quantization inference pass (#30896) (#31162)

* added support for fake_quantize_dequantize_abs_max op in quantization inference pass
* remove const_cast to pass ci
* remove compare operator to pass ci-coverage
* added detailed error message for unregistered tensorrt_subgraph_pass
1 parent b0ec6e8 · commit 011a6a5
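For context on what this pass folds away: a fake quantize-dequantize op simulates low-precision inference in float by snapping each value to the integer grid defined by its abs max. The sketch below is a minimal illustration, not code from the commit; the `FakeQuantDequant` helper and the int8 `range` of 127 are assumptions for the example.

```cpp
#include <cmath>
#include <cstdio>

// Minimal sketch (not Paddle code) of what fake_quantize_dequantize_abs_max
// simulates for a single weight: quantize onto the integer grid defined by
// the abs max, then immediately dequantize. range is assumed to be 127 (int8).
float FakeQuantDequant(float x, float abs_max, float range) {
  float scale = range / abs_max;    // same form of scale the pass computes
  float q = std::round(x * scale);  // the "quantized" integer value
  return q / scale;                 // back to float, keeping the rounding error
}

int main() {
  // abs_max = 2.0, range = 127: 0.3 maps to round(19.05) = 19, and back
  // to 19 / 63.5 ~ 0.2992, the precision loss int8 inference will see.
  std::printf("%f\n", FakeQuantDequant(0.3f, 2.0f, 127.f));
  return 0;
}
```

The inference pass in the first file below deletes the quant-dequant op pair and bakes the equivalent transform directly into the weight tensor (the "perform quantize dequantize operations" block).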

File tree

2 files changed (+71, -19 lines)


paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc

Lines changed: 58 additions & 16 deletions
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
 
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <unordered_set>
@@ -75,6 +76,12 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
     any_op2_desc->Flush();
     auto dequant_type = quant_dequant_op->Op()->Type();
     auto quantized_op_type = any_op2_desc->Type();
+    // get weight tensor
+    auto* weight_tensor =
+        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
+    auto w_dims = weight_tensor->dims();
+    float* quantized_weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
 
     // Get weight scale
     if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
@@ -90,26 +97,64 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
           paddle::platform::is_cpu_place(channel_scale_tensor.place()),
           platform::errors::InvalidArgument(
               "Channel scale tensor's place should be CPU."));
-      const float* channel_scale_data = channel_scale_tensor.data<float>();
-      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
-        weight_scale.push_back(range / channel_scale_data[i]);
+      // compute the channel wise abs max of the weight tensor
+      int quant_axis =
+          BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
+
+      PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
+                        platform::errors::InvalidArgument(
+                            "'quant_axis' should be 0 or 1, but "
+                            "the received is %d",
+                            quant_axis));
+
+      const int64_t channel = w_dims[quant_axis];
+      weight_scale.resize(channel, 0);
+      if (quant_axis == 0) {
+        const int64_t channel_size = weight_tensor->numel() / channel;
+        for (int64_t i = 0; i < channel; i++) {
+          auto* start = quantized_weight_data + i * channel_size;
+          for (int64_t j = 0; j < channel_size; j++) {
+            weight_scale[i] = std::max(std::abs(start[j]), weight_scale[i]);
+          }
+        }
+      } else if (quant_axis == 1) {
+        const int64_t step_i = weight_tensor->numel() / w_dims[0];
+        const int64_t step_j = weight_tensor->numel() / (w_dims[0] * w_dims[1]);
+        for (int64_t i = 0; i < w_dims[0]; i++) {
+          for (int64_t j = 0; j < w_dims[1]; j++) {
+            auto* start = quantized_weight_data + i * step_i + j * step_j;
+            float abs_max = 0;
+            for (int64_t k = 0; k < step_j; k++) {
+              abs_max = std::max(std::abs(start[k]), abs_max);
+            }
+            weight_scale[j] = std::max(weight_scale[j], abs_max);
+          }
+        }
+      }
+      for (int i = 0; i < channel; i++) {
+        PADDLE_ENFORCE_NE(weight_scale[i], 0,
+                          platform::errors::InvalidArgument(
+                              "Weight scale should be nonzero, but get zero."));
+        weight_scale[i] = range / weight_scale[i];
       }
     } else {
       auto scale_name = quant_dequant_op_outscale->Name();
-      const LoDTensor& scale_tensor =
-          scope->GetVar(scale_name)->Get<LoDTensor>();
-      const float* scale_data = scale_tensor.data<float>();
-      weight_scale.push_back((range * range) / scale_data[0] / range);
+      // compute the abs max of the weight tensor
+      float abs_max_weight = 0.;
+      for (int j = 0; j < weight_tensor->numel(); j++) {
+        abs_max_weight =
+            std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
+      }
+      PADDLE_ENFORCE_NE(abs_max_weight, 0,
+                        platform::errors::InvalidArgument(
+                            "Weight scale should be nonzero, but get zero"));
+      weight_scale.push_back((range * range) / abs_max_weight / range);
     }
 
     nodes2rm.insert(quant_dequant_op_outscale);
+
     // perform quantize dequantize operations
-    auto* weight_tensor =
-        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
-    auto w_dims = weight_tensor->dims();
-    float* quantized_weight_data =
-        weight_tensor->mutable_data<float>(platform::CPUPlace());
-    // If quantized op is fc, weight scale size = 1;
+    // If quantized op is not channel wise, weight scale size = 1;
     // If quantized op is conv2d, weight scale size = weight dims[0]
     // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
     if (dequant_type == "fake_quantize_dequantize_abs_max") {
@@ -119,9 +164,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
             "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
             "requires weight scale size = 1, but got %d.",
             quantized_op_type, weight_scale.size()));
-      PADDLE_ENFORCE_NE(weight_scale[0], 0,
-                        platform::errors::InvalidArgument(
-                            "Weight scale should be nonzero, but get zero"));
       for (int j = 0; j < weight_tensor->numel(); j++) {
         // quantized
         quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
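Both branches of this diff now derive the scale from the weights themselves rather than from a stored scale tensor: the channel-wise branch takes the abs max along `quant_axis` and stores `range / abs_max` per channel, and the single-scale branch does the same over the whole tensor (its `(range * range) / abs_max_weight / range` reduces to `range / abs_max_weight`). Below is a minimal standalone sketch of the `quant_axis == 0` case; the `ComputeChannelWiseScales` name, the `main` driver, and the int8 `range` of 127 are illustrative assumptions, not code from the commit.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Minimal sketch (not Paddle code): per-channel abs-max scales for a
// row-major weight tensor quantized along axis 0, mirroring the
// quant_axis == 0 branch of the pass.
std::vector<float> ComputeChannelWiseScales(const std::vector<float>& weights,
                                            int64_t channels, float range) {
  const int64_t channel_size =
      static_cast<int64_t>(weights.size()) / channels;
  std::vector<float> scales(channels, 0.f);
  for (int64_t i = 0; i < channels; ++i) {
    const float* start = weights.data() + i * channel_size;
    for (int64_t j = 0; j < channel_size; ++j) {
      // running abs max over the i-th channel's weights
      scales[i] = std::max(std::abs(start[j]), scales[i]);
    }
  }
  for (float& s : scales) {
    // The real pass rejects a zero abs max with PADDLE_ENFORCE_NE before
    // inverting it; this sketch assumes nonzero channels.
    s = range / s;
  }
  return scales;
}

int main() {
  // Two channels of three weights each; abs maxes are 0.5 and 2.0, so the
  // scales come out as 127 / 0.5 = 254 and 127 / 2.0 = 63.5.
  std::vector<float> w = {0.1f, -0.5f, 0.25f, 2.0f, -1.0f, 0.5f};
  auto scales = ComputeChannelWiseScales(w, /*channels=*/2, /*range=*/127.f);
  return (scales[0] == 254.f && scales[1] == 63.5f) ? 0 : 1;
}
```

For `quant_axis == 1` (e.g. `conv2d_transpose` weights, per the comment in the diff), the same reduction runs with the strides `step_i` and `step_j`, so the abs max is taken over every element that shares output channel `j`.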

paddle/fluid/framework/ir/pass.h

Lines changed: 13 additions & 3 deletions
@@ -206,9 +206,19 @@ class PassRegistry {
   }
 
   std::unique_ptr<Pass> Get(const std::string &pass_type) const {
-    PADDLE_ENFORCE_EQ(Has(pass_type), true,
-                      platform::errors::InvalidArgument(
-                          "Pass %s has not been registered.", pass_type));
+    if (pass_type == "tensorrt_subgraph_pass") {
+      PADDLE_ENFORCE_EQ(Has(pass_type), true,
+                        platform::errors::InvalidArgument(
+                            "Pass %s has not been registered. Please "
+                            "use the paddle inference library "
+                            "compiled with tensorrt or disable "
+                            "the tensorrt engine in inference configuration! ",
+                            pass_type));
+    } else {
+      PADDLE_ENFORCE_EQ(Has(pass_type), true,
+                        platform::errors::InvalidArgument(
+                            "Pass %s has not been registered.", pass_type));
+    }
     return map_.at(pass_type)();
   }
 
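The `pass.h` change is purely diagnostic: requesting `tensorrt_subgraph_pass` from a registry built without TensorRT now produces an actionable message instead of the generic "has not been registered". A minimal sketch of the pattern, using a hypothetical `Registry` class in place of Paddle's `PassRegistry` and plain exceptions in place of `PADDLE_ENFORCE_EQ`:

```cpp
#include <map>
#include <stdexcept>
#include <string>

// Minimal sketch (not Paddle code): a registry lookup that raises a
// targeted, actionable error for one well-known key and a generic error
// for everything else, mirroring the PassRegistry::Get change.
class Registry {
 public:
  void Register(const std::string& name) { map_[name] = true; }

  void Get(const std::string& name) const {
    if (map_.count(name) == 0) {
      if (name == "tensorrt_subgraph_pass") {
        throw std::invalid_argument(
            "Pass " + name +
            " has not been registered. Please use the paddle inference "
            "library compiled with tensorrt or disable the tensorrt "
            "engine in inference configuration!");
      }
      throw std::invalid_argument("Pass " + name +
                                  " has not been registered.");
    }
    // ... construct and return the pass here ...
  }

 private:
  std::map<std::string, bool> map_;
};

int main() {
  Registry registry;
  registry.Register("delete_quant_dequant_filter_op_pass");
  try {
    registry.Get("tensorrt_subgraph_pass");  // triggers the targeted message
  } catch (const std::invalid_argument& e) {
    // e.what() now explains how to fix the build instead of just failing
  }
  return 0;
}
```

Keeping the check inside `Get` rather than at registration time means the friendlier message appears exactly where users hit the failure: when an inference configuration requests the TensorRT engine from a build compiled without it.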
