@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
 
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <unordered_set>
@@ -75,6 +76,12 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
   any_op2_desc->Flush();
   auto dequant_type = quant_dequant_op->Op()->Type();
   auto quantized_op_type = any_op2_desc->Type();
+  // get weight tensor
+  auto* weight_tensor =
+      scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
+  auto w_dims = weight_tensor->dims();
+  float* quantized_weight_data =
+      weight_tensor->mutable_data<float>(platform::CPUPlace());
 
   // Get weight scale
   if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
@@ -90,26 +97,64 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
         paddle::platform::is_cpu_place(channel_scale_tensor.place()),
         platform::errors::InvalidArgument(
             "Channel scale tensor's place should be CPU."));
-    const float* channel_scale_data = channel_scale_tensor.data<float>();
-    for (int i = 0; i < channel_scale_tensor.numel(); i++) {
-      weight_scale.push_back(range / channel_scale_data[i]);
+    // compute the channel wise abs max of the weight tensor
+    int quant_axis =
+        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
+
+    PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
+                      platform::errors::InvalidArgument(
+                          "'quant_axis' should be 0 or 1, but "
+                          "received %d.",
+                          quant_axis));
+
+    const int64_t channel = w_dims[quant_axis];
+    weight_scale.resize(channel, 0);
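+    // quant_axis selects the channel dimension of the filter: dim 0 for
+    // conv2d/mul-style weights, dim 1 for conv2d_transpose-style weights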
+    if (quant_axis == 0) {
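+      // the channel axis is outermost, so channel i's weights are one
+      // contiguous block of channel_size elements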
+      const int64_t channel_size = weight_tensor->numel() / channel;
+      for (int64_t i = 0; i < channel; i++) {
+        auto* start = quantized_weight_data + i * channel_size;
+        for (int64_t j = 0; j < channel_size; j++) {
+          weight_scale[i] = std::max(std::abs(start[j]), weight_scale[i]);
+        }
+      }
+    } else if (quant_axis == 1) {
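+      // the channel axis is dim 1, so channel j's elements are spread across
+      // dim 0: fold the abs max of each contiguous (i, j) block of step_j
+      // elements into weight_scale[j]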
+      const int64_t step_i = weight_tensor->numel() / w_dims[0];
+      const int64_t step_j = weight_tensor->numel() / (w_dims[0] * w_dims[1]);
+      for (int64_t i = 0; i < w_dims[0]; i++) {
+        for (int64_t j = 0; j < w_dims[1]; j++) {
+          auto* start = quantized_weight_data + i * step_i + j * step_j;
+          float abs_max = 0;
+          for (int64_t k = 0; k < step_j; k++) {
+            abs_max = std::max(std::abs(start[k]), abs_max);
+          }
+          weight_scale[j] = std::max(weight_scale[j], abs_max);
+        }
+      }
+    }
+    for (int i = 0; i < channel; i++) {
+      PADDLE_ENFORCE_NE(weight_scale[i], 0,
+                        platform::errors::InvalidArgument(
+                            "Weight scale should be nonzero, but got zero."));
+      weight_scale[i] = range / weight_scale[i];
     }
   } else {
     auto scale_name = quant_dequant_op_outscale->Name();
-    const LoDTensor& scale_tensor =
-        scope->GetVar(scale_name)->Get<LoDTensor>();
-    const float* scale_data = scale_tensor.data<float>();
-    weight_scale.push_back((range * range) / scale_data[0] / range);
+    // compute the abs max of the weight tensor
+    float abs_max_weight = 0.;
+    for (int j = 0; j < weight_tensor->numel(); j++) {
+      abs_max_weight =
+          std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
+    }
+    PADDLE_ENFORCE_NE(abs_max_weight, 0,
+                      platform::errors::InvalidArgument(
+                          "Weight scale should be nonzero, but got zero."));
+    weight_scale.push_back((range * range) / abs_max_weight / range);
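+    // (range * range) / abs_max_weight / range simplifies to
+    // range / abs_max_weight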
   }
 
   nodes2rm.insert(quant_dequant_op_outscale);
+
   // perform quantize dequantize operations
-  auto* weight_tensor =
-      scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
-  auto w_dims = weight_tensor->dims();
-  float* quantized_weight_data =
-      weight_tensor->mutable_data<float>(platform::CPUPlace());
-  // If quantized op is fc, weight scale size = 1;
+  // If quantized op is not channel wise, weight scale size = 1;
   // If quantized op is conv2d, weight scale size = weight dims[0]
   // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
   if (dequant_type == "fake_quantize_dequantize_abs_max") {
@@ -119,9 +164,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
             "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
             "requires weight scale size = 1, but got %d.",
             quantized_op_type, weight_scale.size()));
-    PADDLE_ENFORCE_NE(weight_scale[0], 0,
-                      platform::errors::InvalidArgument(
-                          "Weight scale should be nonzero, but get zero"));
     for (int j = 0; j < weight_tensor->numel(); j++) {
       // quantized
       quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];