
Commit 96e1a03

Added optimization for dequantized constant folding (#1316)
Signed-off-by: Tom Wildenhain <[email protected]>
Parent: 4becde0
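
In effect, the new pass rewrites constant subgraphs of the form

    const X (uint8) -> DequantizeLinear -> Reshape -> ...

into

    const reshaped-X (uint8) -> DequantizeLinear -> ...

applying the shape-only op to the quantized constant, so the weights stay stored in compact quantized form and the dequantize still happens only once at runtime.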

File tree: 4 files changed, +99 -2 lines

tests/test_optimizers.py

Lines changed: 25 additions & 1 deletion

@@ -9,7 +9,7 @@
 
 import unittest
 import numpy as np
-from onnx import helper, TensorProto, OperatorSetIdProto
+from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
 from backend_test_base import Tf2OnnxBackendTestBase
 from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
 from tf2onnx import utils, constants

@@ -1241,6 +1241,30 @@ def test_const_fold_cast_with_const(self):
 
     # Const Fold Optimizer Tests End
 
+    # Const Dequantize Optimizer Tests Start
+
+    @check_opset_min_version(10, "DequantizeLinear")
+    def test_const_dequantize_reshape(self):
+        inputval = numpy_helper.from_array(np.random.randint(0, 100, (2, 3, 4, 5), np.uint8), name='X')
+        scale = numpy_helper.from_array(np.array(0.75, dtype=np.float32), name='scale')
+        zero_point = numpy_helper.from_array(np.array(3, dtype=np.uint8), name='zero_point')
+        shape = numpy_helper.from_array(np.array([6, 20], dtype=np.int64), name='shape')
+        node1 = helper.make_node("DequantizeLinear", ["X", "scale", "zero_point"], ["Y"], name="dequantize")
+        node2 = helper.make_node("Reshape", ["Y", "shape"], ["Z"], name="reshape")
+
+        graph = helper.make_graph(
+            [node1, node2],
+            "const-dequantize-test",
+            [],
+            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (6, 20))],
+            [inputval, scale, zero_point, shape]
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["Z"], {}, model_proto, "Reshape", 0)
+
+    # Const Dequantize Optimizer Tests End
+
     def test_transpose_back_to_back_non_const(self):
 
         node0 = helper.make_node("Transpose", ["u"], ["v"], perm=[0, 2, 3, 1], name="trans_0")
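As I read the test harness, the trailing arguments to run_and_compare assert that the optimized model still produces the same Z and that zero "Reshape" nodes remain, i.e. the reshape has been folded into the quantized initializer and only the DequantizeLinear survives.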

tf2onnx/optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -16,13 +16,15 @@
 from .loop_optimizer import LoopOptimizer
 from .back_to_back_optimizer import BackToBackOptimizer
 from .upsample_optimizer import UpsampleOptimizer
+from .const_dequantize_optimizer import ConstDequantizeOptimizer
 from .. import logging
 
 # optimizer sequence need to be considered carefully
 _optimizers = OrderedDict([
     ("optimize_transpose", TransposeOptimizer),
     ("remove_redundant_upsample", UpsampleOptimizer),
     ("fold_constants", ConstFoldOptimizer),
+    ("const_dequantize_optimizer", ConstDequantizeOptimizer),
     ("loop_optimizer", LoopOptimizer),
     # merge_duplication should be used after optimize_transpose
     # for optimize_transpose may have some trans nodes that can be merge
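
Note the registration order: const_dequantize_optimizer runs immediately after fold_constants. As the const_fold_optimizer.py hunk below shows, the generic constant folder now skips DequantizeLinear nodes, leaving this dedicated pass to fold through them only when it is safe to do so.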
tf2onnx/optimizer/const_dequantize_optimizer.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""const dequantize Optimizer.
+   if a dequantize op's inputs are const we may be able to fold it through the next op
+"""
+
+from .optimizer_base import GraphOptimizerBase
+from .const_fold_optimizer import ConstFoldOptimizer
+
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
+
+
+class ConstDequantizeOptimizer(GraphOptimizerBase):
+
+    def __init__(self):  # pylint: disable=useless-super-delegation
+        super(ConstDequantizeOptimizer, self).__init__()
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, graph):
+        graph_changed = True
+        while graph_changed:
+            graph_changed = False
+            ops = graph.get_nodes()
+            for op in ops:
+                if self._fold_node(op, graph):
+                    graph_changed = True
+                    self.graph_been_opt = True
+        return graph
+
+    def _fold_node(self, node, graph):
+        """ if a dequantize op's inputs are const and it is fed into a tensor reshaping op, we can apply the op
+            directly to the quantized inputs. Returns True if the graph is changed.
+        """
+        if node.type not in ["Transpose", "Reshape", "Unsqueeze"]:
+            return False
+        dequant_node = node.inputs[0]
+        if dequant_node.type != "DequantizeLinear":
+            return False
+        if len(graph.find_output_consumers(dequant_node.output[0])) > 1:
+            return False
+        if not self._all_inputs_are_const(node.inputs[1:]) or self._is_graph_output(node, graph):
+            return False
+        if not self._all_inputs_are_const(dequant_node.inputs):
+            return False
+        graph.replace_input(node, node.input[0], dequant_node.input[0], 0)
+        const_outputs = ConstFoldOptimizer.compute_const_folding(node, graph)
+        graph.replace_all_inputs(node.output[0], dequant_node.output[0])
+        graph.remove_node(node.name)
+        dequant_const = dequant_node.inputs[0]
+        if len(graph.find_output_consumers(dequant_const.output[0])) > 1:
+            dequant_const = graph.copy_const(dequant_const)
+            graph.replace_input(dequant_node, dequant_node.input[0], dequant_const.output[0], 0)
+        dequant_const.set_tensor_value(const_outputs[0])
+        return True
+
+    @staticmethod
+    def _all_inputs_are_const(nodes):
+        return all(node.is_const() for node in nodes if node)
+
+    @staticmethod
+    def _is_graph_output(node, graph):
+        node_out_set = set(node.output)
+        graph_out_set = set(graph.outputs)
+        return node_out_set.intersection(graph_out_set)
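
Why the rewrite is safe: DequantizeLinear computes (x - zero_point) * scale elementwise, so with a scalar (per-tensor) scale and zero point it commutes with shape-only ops such as Reshape, Transpose, and Unsqueeze. A minimal numpy sketch of the equivalence, mirroring the constants in the test above (my illustration, not part of the commit):

import numpy as np

# Quantized constant and per-tensor quantization parameters, as in the test.
x = np.random.randint(0, 100, (2, 3, 4, 5)).astype(np.uint8)
scale, zero_point = np.float32(0.75), np.uint8(3)

def dequantize(q):
    # DequantizeLinear: y = (x - x_zero_point) * x_scale, elementwise.
    return (q.astype(np.int32) - np.int32(zero_point)) * scale

# Original graph order: dequantize to float, then reshape the float tensor.
before = dequantize(x).reshape(6, 20)
# Optimized order: reshape the uint8 constant first, then dequantize once.
after = dequantize(x.reshape(6, 20))

assert np.array_equal(before, after)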

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 5 additions & 1 deletion

@@ -54,7 +54,7 @@ def _should_skip(node):
         if node.is_const() or node.is_graph_input():
             return True
 
-        skip_type = ["Identity"]
+        skip_type = ["Identity", "DequantizeLinear"]
         if node.type in skip_type:
             return True
 
@@ -73,6 +73,10 @@ def _fold_node(self, node, graph):
         self.logger.debug("need to add function to fold op %s whose op_type is %s", node.name, node.type)
         return False
 
+    @staticmethod
+    def compute_const_folding(node, graph):
+        return _func_map[node.type](node, graph)
+
     @staticmethod
     def _all_inputs_are_const(nodes):
         return all(node.is_const() for node in nodes if node)
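
Two supporting changes here: DequantizeLinear joins skip_type so the generic constant folder leaves dequantize nodes alone (folding them would materialize full-precision weights and defeat the point of quantization), and the private fold dispatch table _func_map is exposed via the new compute_const_folding static method so ConstDequantizeOptimizer can fold a single reshaping node without running the whole pass.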
