@@ -158,14 +158,24 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
158
158
auto conv_weight_d = conv_weight_t ->mutable_data <int8_t >();
159
159
// compute new conv_weight for int8
160
160
auto weight_scale = conv_op_desc->GetInputScale (weight_name);
161
+ std::vector<float > weight_scale_dup (alpha_tensor.numel (), 0 );
162
+ if (weight_scale.size () == 1 ) {
163
+ for (int i = 0 ; i < weight_scale_dup.size (); i++) {
164
+ weight_scale_dup[i] = weight_scale[0 ];
165
+ }
166
+ } else {
167
+ for (int i = 0 ; i < weight_scale_dup.size (); i++) {
168
+ weight_scale_dup[i] = weight_scale[i];
169
+ }
170
+ }
161
171
if (conv_type_ == " conv2d_transpose" ) {
162
172
int cout = conv_weight_t ->dims ()[1 ] * groups;
163
173
int cin_group = conv_weight_t ->dims ()[0 ] / groups;
164
174
int c_size = cout * conv_weight_t ->dims ()[2 ] * conv_weight_t ->dims ()[3 ];
165
175
int hw = conv_weight_t ->dims ()[2 ] * conv_weight_t ->dims ()[3 ];
166
176
for (int k = 0 ; k < cin_group; ++k) {
167
177
for (int i = 0 ; i < cout; ++i) {
168
- weight_scale [i] *= fabsf (alpha_data[i]);
178
+ weight_scale_dup [i] *= fabsf (alpha_data[i]);
169
179
if (alpha_data[i] < 0 .f ) {
170
180
auto ptr_row = conv_weight_d + k * c_size + i * hw;
171
181
for (int j = 0 ; j < hw; ++j) {
@@ -176,7 +186,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
176
186
}
177
187
} else {
178
188
for (int i = 0 ; i < h; ++i) {
179
- weight_scale [i] *= fabsf (alpha_data[i]);
189
+ weight_scale_dup [i] *= fabsf (alpha_data[i]);
180
190
if (alpha_data[i] < 0 .f ) {
181
191
auto ptr_row = conv_weight_d + i * w;
182
192
for (int j = 0 ; j < w; ++j) {
@@ -185,7 +195,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
185
195
}
186
196
}
187
197
}
188
- conv_op_desc->SetInputScale (weight_name, weight_scale );
198
+ conv_op_desc->SetInputScale (weight_name, weight_scale_dup );
189
199
} else if (is_weight_quantization) {
190
200
std::string scale_name = conv_weight_name + " _quant_scale" ;
191
201
if (conv_op_desc->HasAttr (scale_name)) {
0 commit comments