
Commit 4d83a0e

Implement const dequantize pushing for per-axis dequantization (#1321)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 70bc2b6 commit 4d83a0e
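
For context, "per-axis" (per-channel) DequantizeLinear applies a separate scale and zero point to each slice along one axis, computing y = (x - zero_point) * scale with both constants broadcast along that axis; pushing constant folding through such a node therefore requires tracking where that axis ends up after the folded op. A minimal numpy sketch of the semantics (illustrative only, not part of this commit):

import numpy as np

def dequantize_per_axis(x, scale, zero_point, axis):
    # Reshape the 1-D scale and zero point so they broadcast along `axis`.
    bshape = [1] * x.ndim
    bshape[axis] = -1
    scale = np.asarray(scale, dtype=np.float32).reshape(bshape)
    zero_point = np.asarray(zero_point, dtype=np.float32).reshape(bshape)
    return (x.astype(np.float32) - zero_point) * scale

x = np.random.randint(0, 100, (2, 3, 4, 5), np.uint8)
y = dequantize_per_axis(x, [0.75, 1.0, 0.2], [3, 4, 50], axis=1)  # one scale per slice of axis 1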

2 files changed: 125 additions & 0 deletions

tests/test_optimizers.py

Lines changed: 80 additions & 0 deletions
@@ -1512,6 +1512,86 @@ def test_const_dequantize_reshape(self):
         model_proto = self.make_model(graph, producer_name="onnx-tests")
         self.run_and_compare(["Z"], {}, model_proto, "Reshape", 0)

+    @check_opset_min_version(13, "DequantizeLinear")
+    def test_const_dequantize_reshape_per_channel(self):
+        inputval = numpy_helper.from_array(np.random.randint(0, 100, (2, 3, 4, 5), np.uint8), name='X')
+        scale = numpy_helper.from_array(np.array([0.75, 1., 0.2], dtype=np.float32), name='scale')
+        zero_point = numpy_helper.from_array(np.array([3, 4, 50], dtype=np.uint8), name='zero_point')
+        shape = numpy_helper.from_array(np.array([1, 1, 2, 3, 20], dtype=np.int64), name='shape')
+        node1 = helper.make_node("DequantizeLinear", ["X", "scale", "zero_point"], ["Y"], name="dequantize", axis=-3)
+        node2 = helper.make_node("Reshape", ["Y", "shape"], ["Z"], name="reshape")
+
+        graph = helper.make_graph(
+            [node1, node2],
+            "const-dequantize-test",
+            [],
+            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (1, 1, 2, 3, 20))],
+            [inputval, scale, zero_point, shape]
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["Z"], {}, model_proto, "Reshape", 0)
+
+    @check_opset_min_version(13, "DequantizeLinear")
+    def test_const_dequantize_reshape_per_channel_skipped(self):
+        inputval = numpy_helper.from_array(np.random.randint(0, 100, (2, 3, 4, 5), np.uint8), name='X')
+        scale = numpy_helper.from_array(np.array([0.75, 1., 0.2, 0.3], dtype=np.float32), name='scale')
+        zero_point = numpy_helper.from_array(np.array([3, 4, 50, 2], dtype=np.uint8), name='zero_point')
+        shape = numpy_helper.from_array(np.array([1, 6, 2, 2, 5], dtype=np.int64), name='shape')
+        node1 = helper.make_node("DequantizeLinear", ["X", "scale", "zero_point"], ["Y"], name="dequantize", axis=2)
+        node2 = helper.make_node("Reshape", ["Y", "shape"], ["Z"], name="reshape")
+
+        graph = helper.make_graph(
+            [node1, node2],
+            "const-dequantize-test",
+            [],
+            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (1, 6, 2, 2, 5))],
+            [inputval, scale, zero_point, shape]
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        # No optimization can be done here since the channel axis has changed size
+        self.run_and_compare(["Z"], {}, model_proto, "Reshape", 1)
+
+    @check_opset_min_version(13, "DequantizeLinear")
+    def test_const_dequantize_transpose_per_channel(self):
+        inputval = numpy_helper.from_array(np.random.randint(0, 100, (2, 3, 4, 5), np.uint8), name='X')
+        scale = numpy_helper.from_array(np.array([0.75, 1., 0.2], dtype=np.float32), name='scale')
+        zero_point = numpy_helper.from_array(np.array([3, 4, 50], dtype=np.uint8), name='zero_point')
+        node1 = helper.make_node("DequantizeLinear", ["X", "scale", "zero_point"], ["Y"], name="dequantize", axis=1)
+        node2 = helper.make_node("Transpose", ["Y"], ["Z"], name="transpose", perm=[0, 2, 3, 1])
+
+        graph = helper.make_graph(
+            [node1, node2],
+            "const-dequantize-test",
+            [],
+            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (2, 4, 5, 3))],
+            [inputval, scale, zero_point]
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["Z"], {}, model_proto, "Transpose", 0)
+
+    @check_opset_min_version(13, "DequantizeLinear")
+    def test_const_dequantize_unsqueeze_per_channel(self):
+        inputval = numpy_helper.from_array(np.random.randint(0, 100, (2, 3, 4, 5), np.uint8), name='X')
+        scale = numpy_helper.from_array(np.array([0.75, 1., 0.2], dtype=np.float32), name='scale')
+        zero_point = numpy_helper.from_array(np.array([3, 4, 50], dtype=np.uint8), name='zero_point')
+        axes = numpy_helper.from_array(np.array([-1, 0, -8, 3, 5], dtype=np.int64), name='axes')
+        node1 = helper.make_node("DequantizeLinear", ["X", "scale", "zero_point"], ["Y"], name="dequantize", axis=1)
+        node2 = helper.make_node("Unsqueeze", ["Y", "axes"], ["Z"], name="unsqueeze")
+
+        graph = helper.make_graph(
+            [node1, node2],
+            "const-dequantize-test",
+            [],
+            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (1, 1, 2, 1, 3, 1, 4, 5, 1))],
+            [inputval, scale, zero_point, axes]
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["Z"], {}, model_proto, "Unsqueeze", 0)
+
     # Const Dequantize Optimizer Tests End

     def test_transpose_back_to_back_non_const(self):
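
The rewrite these tests expect can also be sanity-checked outside the optimizer. The following hypothetical numpy snippet (not part of the commit; deq is an illustrative helper) mirrors test_const_dequantize_reshape_per_channel: dequantizing on axis -3 and then reshaping matches reshaping the constant first and dequantizing on the remapped axis 3, which is exactly the graph the optimizer is expected to produce.

import numpy as np

x = np.random.randint(0, 100, (2, 3, 4, 5), np.uint8)
scale = np.array([0.75, 1.0, 0.2], np.float32)
zp = np.array([3, 4, 50], np.uint8)

def deq(t, axis):
    # Broadcast the per-channel scale/zero point along `axis` and dequantize.
    bshape = [1] * t.ndim
    bshape[axis] = -1
    return (t.astype(np.float32) - zp.astype(np.float32).reshape(bshape)) * scale.reshape(bshape)

before = deq(x, axis=-3).reshape(1, 1, 2, 3, 20)   # DequantizeLinear(axis=-3) -> Reshape
after = deq(x.reshape(1, 1, 2, 3, 20), axis=3)     # Reshape folded -> DequantizeLinear(axis=3)
assert np.allclose(before, after)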

tf2onnx/optimizer/const_dequantize_optimizer.py

Lines changed: 45 additions & 0 deletions
@@ -45,6 +45,14 @@ def _fold_node(self, node, graph):
             return False
         if not self._all_inputs_are_const(dequant_node.inputs):
             return False
+        if len(dequant_node.inputs[1].get_tensor_value(as_list=False).flatten()) != 1:
+            # If using per-channel quantization, we must compute the new axis
+            old_axis = dequant_node.get_attr_value("axis")
+            input_shape = dequant_node.inputs[0].get_tensor_value(as_list=False).shape
+            new_axis = self.compute_new_axis(node, graph, old_axis, input_shape)
+            if new_axis is None:
+                return False
+            dequant_node.set_attr("axis", new_axis)
         graph.replace_input(node, node.input[0], dequant_node.input[0], 0)
         const_outputs = ConstFoldOptimizer.compute_const_folding(node, graph)
         graph.replace_all_inputs(node.output[0], dequant_node.output[0])
@@ -65,3 +73,40 @@ def _is_graph_output(node, graph):
         node_out_set = set(node.output)
         graph_out_set = set(graph.outputs)
         return node_out_set.intersection(graph_out_set)
+
+    @staticmethod
+    def compute_new_axis(node, graph, old_axis, input_shape):
+        if old_axis < 0:
+            old_axis += len(input_shape)
+        if node.type == "Transpose":
+            perm = node.get_attr_value("perm")
+            if perm is None:
+                return None
+            return perm.index(old_axis)
+        if node.type == "Reshape":
+            prod = 1
+            for d in input_shape[:old_axis+1]:
+                prod *= d
+            new_shape = node.inputs[1].get_tensor_value(as_list=True)
+            new_prod = 1
+            for i, d in enumerate(new_shape):
+                new_prod *= d
+                if new_prod == prod:
+                    if new_shape[i] == input_shape[old_axis]:
+                        return i
+                    return None
+            return None
+        if node.type == "Unsqueeze":
+            if graph.opset >= 13:
+                axes = node.inputs[1].get_tensor_value(as_list=True)
+            else:
+                axes = node.get_attr_value("axes")
+            new_rank = len(input_shape) + len(axes)
+            axes = [axis if axis >= 0 else axis + new_rank for axis in axes]
+            for i in range(new_rank):
+                if i not in axes:
+                    if old_axis == 0:
+                        return i
+                    old_axis -= 1
+            return None
+        return None
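
To make the axis arithmetic above concrete, here is a hypothetical standalone sketch (not from the commit; reshape_axis and unsqueeze_axis are illustrative re-implementations) that reproduces what compute_new_axis returns for the shapes used in the new tests:

import numpy as np

# Transpose with perm=[0, 2, 3, 1]: old channel axis 1 moves to the position of 1 in perm.
assert [0, 2, 3, 1].index(1) == 3

def reshape_axis(input_shape, new_shape, old_axis):
    # The running product of dims up to and including the channel axis must be
    # reproduced by some prefix of the new shape, with the channel dim kept intact.
    prod = int(np.prod(input_shape[:old_axis + 1]))
    new_prod = 1
    for i, d in enumerate(new_shape):
        new_prod *= d
        if new_prod == prod:
            return i if new_shape[i] == input_shape[old_axis] else None
    return None

assert reshape_axis((2, 3, 4, 5), (1, 1, 2, 3, 20), 1) == 3    # reshape test: new axis 3
assert reshape_axis((2, 3, 4, 5), (1, 6, 2, 2, 5), 2) is None  # skipped test: channel dim was split

def unsqueeze_axis(input_rank, axes, old_axis):
    # Positions not occupied by inserted axes hold the original dims in order.
    new_rank = input_rank + len(axes)
    axes = {a if a >= 0 else a + new_rank for a in axes}
    kept = [i for i in range(new_rank) if i not in axes]
    return kept[old_axis]

assert unsqueeze_axis(4, [-1, 0, -8, 3, 5], 1) == 4            # unsqueeze test: new axis 4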
