Commit 9ffb43b

ceci3 and XGZhang11 authored
Fix a bug of quantization (#36982) (#37381)
* fix a quantization bug

Co-authored-by: XGZhang <[email protected]>
1 parent 109f8a8 commit 9ffb43b

File tree

1 file changed: +3 −2 lines changed


python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

Lines changed: 3 additions & 2 deletions
@@ -1292,10 +1292,11 @@ def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
             var_type=output_var_node.type(),
             shape=output_var_node.shape(),
             var_dtype=output_var_node.dtype())
+        x_num_col_dims = 1
+        if op_node.name() in ['matmul', 'matmul_v2', 'mul']:
+            x_num_col_dims = len(op_node.outputs[0].shape()) - 1
         if op_node.op().has_attr("x_num_col_dims"):
             x_num_col_dims = op_node.op().attr("x_num_col_dims")
-        else:
-            x_num_col_dims = 1
         dequant_op_node = graph.create_op_node(
             op_type='fake_channel_wise_dequantize_max_abs',
             attrs={
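For context: the fix changes the default x_num_col_dims passed to the fake_channel_wise_dequantize_max_abs op. Previously it always fell back to 1; now matmul-style ops (matmul, matmul_v2, mul) default to the rank of the output minus one, so channel-wise dequantization is applied along the last axis of the output, while an explicit x_num_col_dims attribute on the op still takes precedence. A minimal standalone sketch of the selection logic, using plain Python inputs (op_name, output_shape, op_attrs are hypothetical stand-ins for Paddle's IrGraph node accessors):

    # Sketch of the fixed x_num_col_dims selection, assuming plain Python
    # inputs instead of Paddle's IrGraph node accessors.
    def select_x_num_col_dims(op_name, output_shape, op_attrs):
        """Pick x_num_col_dims for the channel-wise dequant op.

        op_name      -- e.g. 'matmul', 'matmul_v2', 'mul', 'conv2d'
        output_shape -- shape of the op's first output, e.g. [8, 16, 32]
        op_attrs     -- dict of op attributes (may hold 'x_num_col_dims')
        """
        x_num_col_dims = 1  # old behavior: unconditional fallback to 1
        if op_name in ['matmul', 'matmul_v2', 'mul']:
            # new behavior: dequantize along the output's last axis
            x_num_col_dims = len(output_shape) - 1
        if 'x_num_col_dims' in op_attrs:
            # an attribute set on the op itself still wins
            x_num_col_dims = op_attrs['x_num_col_dims']
        return x_num_col_dims

    # A rank-3 matmul output now yields 2 instead of 1:
    assert select_x_num_col_dims('matmul_v2', [8, 16, 32], {}) == 2
    assert select_x_num_col_dims('conv2d', [8, 16, 32, 32], {}) == 1
    assert select_x_num_col_dims('mul', [8, 32], {'x_num_col_dims': 1}) == 1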

0 commit comments
