
Commit 9c7b56c

Add optimizer to push q and dq ops into place (#1497)
* Add optimizer to push q and dq ops into place
  Signed-off-by: Tom Wildenhain <[email protected]>
* Increase ci pipeline ort version
  Signed-off-by: Tom Wildenhain <[email protected]>
* Update min opset
  Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 93d2437 commit 9c7b56c
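The new q_dq_optimizer rewrites the graph so that QuantizeLinear sits as early and DequantizeLinear as late as possible, leaving compute ops wrapped in DQ -> op -> Q patterns that onnxruntime can fuse. The rewrite is valid for ops that only rearrange elements (Transpose, Reshape, Gather, ...) because such ops commute with quantization when the Q and DQ share the same parameters. A minimal numpy sketch of that property (illustration only, not part of this commit; the scale and zero point are assumed values):

import numpy as np

scale, zero_point = 0.02, 0  # assumed uint8 quantization params

def quantize(x):
    # models QuantizeLinear: float32 -> uint8
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

def dequantize(q):
    # models DequantizeLinear: uint8 -> float32
    return (q.astype(np.float32) - zero_point) * scale

q = quantize(np.random.uniform(1.0, 6.0, size=(3, 3, 2)).astype(np.float32))

# Before the rewrite: DequantizeLinear -> Transpose -> QuantizeLinear
before = quantize(np.transpose(dequantize(q), (1, 2, 0)))
# After the rewrite the Transpose runs directly on the quantized tensor
after = np.transpose(q, (1, 2, 0))
assert np.array_equal(before, after)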

File tree

6 files changed: 238 additions & 5 deletions

ci_build/azure_pipelines/templates/job_generator.yml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ parameters:
   tf_versions: ['']
   onnx_versions: ['']
   onnx_opsets: ['13', '12', '11', '10', '9', '8', '7']
-  onnx_backends: {onnxruntime: ['1.6.0']}
+  onnx_backends: {onnxruntime: ['1.7.0']}
   job: {}
   run_setup: 'True'
   report_coverage: 'False'

tests/test_backend.py

Lines changed: 65 additions & 0 deletions
@@ -2734,6 +2734,71 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @check_tf_min_version("1.15")
+    @check_opset_min_version(10, "quantize_and_dequantize")
+    def test_qdq_optimizer(self):
+        x_shape = [3, 3, 2]
+        x_val = np.arange(1, 1+np.prod(x_shape)).astype("float32").reshape(x_shape)
+        def func(x):
+            x_ = quantize_and_dequantize(x, 1.0, 6.0, signed_input=False, range_given=True)
+            x_ = tf.transpose(x_, [1, 2, 0])
+            x_ = tf.reshape(x_, tf.constant([9, 2]))
+            x_ = quantize_and_dequantize(x_, 1.0, 6.0, signed_input=False, range_given=True)
+            return tf.identity(x_, name=_TFOUTPUT)
+        _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val},
+                                graph_validator=lambda g: check_op_count(g, "DequantizeLinear", 1, disabled=False))
+
+    @check_tf_min_version("1.15")
+    @check_opset_min_version(10, "quantize_and_dequantize")
+    def test_qdq_optimizer_split_concat(self):
+        x_shape = [7, 3, 5]
+        y_shape = [7, 2, 5]
+        x_val = np.arange(1, 1+np.prod(x_shape)).astype("float32").reshape(x_shape)
+        y_val = np.arange(1, 1+np.prod(y_shape)).astype("float32").reshape(y_shape)
+        def func(x, y):
+            x_ = quantize_and_dequantize(x, 1.0, 30.0, signed_input=False, range_given=True)
+            a, _, c = tf.unstack(x_, axis=1)
+            ac = tf.stack([a, c], axis=1)
+            y_ = quantize_and_dequantize(y, 1.0, 30.0, signed_input=False, range_given=True)
+            m = tf.matmul(ac, tf.transpose(y_, [0, 2, 1]))
+            m_ = m[2:, :, :]
+            m_ = quantize_and_dequantize(m_, 1.0, 30.0, signed_input=False, range_given=True)
+            return tf.identity(m_, name=_TFOUTPUT)
+        def validate_graph(g):
+            # MatMul should be wrapped in Dq/Q
+            for n in g.get_nodes():
+                if n.type == "MatMul":
+                    if not all(inp.type == "DequantizeLinear" for inp in n.inputs):
+                        return False
+                    if not all(c.type == "QuantizeLinear" for c in g.find_output_consumers(n.output[0])):
+                        return False
+            return True
+
+        _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}, graph_validator=validate_graph)
+
+    @check_tf_min_version("1.15")
+    @check_opset_min_version(11, "ScatterND")
+    def test_qdq_optimizer_scatter(self):
+        x_val = np.array([10, 20, 30, 40], dtype=np.float32).reshape((4))
+        y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))
+        z_val = np.array([8, 11], dtype=np.float32).reshape((2))
+
+        def func(x, y, z):
+            x_ = quantize_and_dequantize(x, 1.0, 30.0, signed_input=False, range_given=True)
+            z_ = quantize_and_dequantize(z, 1.0, 30.0, signed_input=False, range_given=True)
+            w = tf.tensor_scatter_nd_update(x_, y, z_)
+            w_ = quantize_and_dequantize(w, 1.0, 30.0, signed_input=False, range_given=True)
+            return tf.identity(w_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val},
+                            graph_validator=lambda g: check_op_count(g, "DequantizeLinear", 1, disabled=False))
+
+        def func(x, y, z):
+            x_ = quantize_and_dequantize(x, 1.0, 30.0, signed_input=False, range_given=True)
+            w = tf.tensor_scatter_nd_update(x_, y, z)
+            w_ = quantize_and_dequantize(w, 1.0, 30.0, signed_input=False, range_given=True)
+            return tf.identity(w_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
+
     @check_tf_min_version("1.15")
     @check_opset_min_version(10, "quantize_and_dequantize")
     def test_qdq_dyn_range_unsigned_input(self):

tf2onnx/optimizer/__init__.py

Lines changed: 7 additions & 3 deletions
@@ -19,6 +19,7 @@
 from .const_dequantize_optimizer import ConstDequantizeOptimizer
 from .reshape_optimizer import ReshapeOptimizer
 from .global_pool_optimizer import GlobalPoolOptimizer
+from .q_dq_optimizer import QDQOptimizer
 from .. import logging
 
 # optimizer sequence need to be considered carefully
@@ -32,9 +33,10 @@
     # for optimize_transpose may have some trans nodes that can be merge
     ("merge_duplication", MergeDuplicatedNodesOptimizer),
     ("reshape_optimizer", ReshapeOptimizer),
+    ("global_pool_optimizer", GlobalPoolOptimizer),
+    ("q_dq_optimizer", QDQOptimizer),
     ("remove_identity", IdentityOptimizer),
     ("remove_back_to_back", BackToBackOptimizer),
-    ("global_pool_optimizer", GlobalPoolOptimizer),
 ])
 
 
@@ -50,6 +52,7 @@ def optimize_graph(graph, catch_errors=True):
     before = graph.dump_node_statistics()
     opts = _get_optimizers()
     continue_flag = True
+    iteration = 0
     while continue_flag:
         continue_flag = False
         for name, factory in opts.items():
@@ -58,15 +61,16 @@ def optimize_graph(graph, catch_errors=True):
             try:
                 current = copy.deepcopy(graph)
                 opt = factory()
-                graph = opt.optimize(current) or graph
+                graph = opt.optimize(current, iteration) or graph
                 continue_flag = continue_flag or opt.graph_been_opt
             except Exception:  # pylint: disable=broad-except
                 # if current optimizer fails, continue with other optimizers
                 logger.warning("Failed to apply %s", name, exc_info=1)
             else:
                 opt = factory()
-                graph = opt.optimize(graph)
+                graph = opt.optimize(graph, iteration)
                 continue_flag = continue_flag or opt.graph_been_opt
+        iteration += 1
 
     try:
         graph.topological_sort(graph.get_nodes())
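Callers of the optimizer pipeline are unchanged by the new iteration counter; it is threaded through GraphOptimizerBase.optimize (next file) so individual passes can consult opt_iteration. A short usage sketch, assuming g is an existing tf2onnx Graph:

from tf2onnx.optimizer import optimize_graph

# catch_errors=True keeps optimizing even if one pass raises (see the try/except above)
g = optimize_graph(g, catch_errors=True)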

tf2onnx/optimizer/optimizer_base.py

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@ class GraphOptimizerBase(object):
     def __init__(self):
         self._logger = logging.getLogger('.'.join(__name__.split('.')[:-1] + [self.__class__.__name__]))
         self._graph_been_opt = False
+        self.opt_iteration = 0
 
     @property
     def logger(self):
@@ -34,10 +35,11 @@ def graph_been_opt(self):
     def graph_been_opt(self, value):
         self._graph_been_opt = value
 
-    def optimize(self, graph):
+    def optimize(self, graph, iteration):
         """ Optimize graph, return optimized graph. """
         before = graph.dump_node_statistics()
 
+        self.opt_iteration = iteration
         graph = self._optimize(graph)
         graph.update_proto()
         graph.delete_unused_nodes(graph.outputs)

tf2onnx/optimizer/q_dq_optimizer.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""q dq optimizer
+Pushes Quantize ops up and Dequantize ops down to maximize DQ -> op -> Q patterns for ORT
+Does not work for per-channel quantization yet
+"""
+
+from .optimizer_base import GraphOptimizerBase
+
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
+
+
+class QDQOptimizer(GraphOptimizerBase):
+
+    def __init__(self):  # pylint: disable=useless-super-delegation
+        super(QDQOptimizer, self).__init__()
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, graph):
+        graph_changed = True
+        while graph_changed:
+            graph_changed = False
+            ops = graph.get_nodes()
+            for op in ops:
+                if op.type == "QuantizeLinear" and self._optimize_quantize(op, graph):
+                    graph_changed = True
+                    self.graph_been_opt = True
+                elif op.type == "DequantizeLinear" and self._optimize_dequantize(op, graph):
+                    graph_changed = True
+                    self.graph_been_opt = True
+        return graph
+
+    def _optimize_quantize(self, quant_node, graph):
+        if 'axis' in quant_node.attr:
+            return False
+        node = quant_node.inputs[0]
+        if node.type == "DequantizeLinear":
+            # Remove DQ -> Q
+            if not self.has_same_quantization_params(quant_node, node):
+                return False
+            if quant_node.output[0] in graph.outputs or node.output[0] in graph.outputs:
+                return False
+            graph.replace_all_inputs(quant_node.output[0], node.input[0])
+            if not graph.find_output_consumers(quant_node.output[0]):
+                graph.remove_node(quant_node.name)
+            if not graph.find_output_consumers(node.output[0]):
+                graph.remove_node(node.name)
+            return True
+
+        # Push quantize nodes up
+        tensor_idx = is_tensor_op(graph, node)
+        if tensor_idx is None:
+            return False
+        inp_indices, out_indices = tensor_idx
+        for i in out_indices:
+            consumers = graph.find_output_consumers(node.output[i])
+            if node.output[i] in graph.outputs:
+                return False
+            for c in consumers:
+                if c.type != "QuantizeLinear":
+                    return False
+                if not self.has_same_quantization_params(c, quant_node):
+                    return False
+                if c.output[0] in graph.outputs:
+                    return False
+        # All outputs are quantized. Push quantization up to input.
+        for i in inp_indices:
+            inp_q = self.make_q_or_dq(graph, "QuantizeLinear", node.input[i], quant_node, node.name)
+            graph.replace_input(node, node.input[i], inp_q.output[0], i)
+
+        for i in out_indices:
+            graph.copy_dtype(quant_node.output[0], node.output[i])
+            consumers = graph.find_output_consumers(node.output[i])
+            for c in consumers:
+                graph.replace_all_inputs(c.output[0], node.output[i])
+
+        return True
+
+    def _optimize_dequantize(self, dequant_node, graph):
+        if 'axis' in dequant_node.attr:
+            return False
+        # Push dequantize nodes down
+        consumers = graph.find_output_consumers(dequant_node.output[0])
+        for node in consumers:
+            if self._optimize_dequantize_and_node(dequant_node, node, graph):
+                return True
+        return False
+
+    def _optimize_dequantize_and_node(self, dequant_node, node, graph):
+        tensor_idx = is_tensor_op(graph, node)
+        if tensor_idx is None:
+            return False
+        inp_indices, out_indices = tensor_idx
+        for i in inp_indices:
+            inp = node.inputs[i]
+            if inp.type != "DequantizeLinear":
+                return False
+            if not self.has_same_quantization_params(inp, dequant_node):
+                return False
+            if inp.output[0] in graph.outputs:
+                return False
+        for i in out_indices:
+            if node.output[i] in graph.outputs:
+                return False
+        # All inputs are dequantized. Push dequantization down to output.
+        for i in inp_indices:
+            # Skip the dequantize on the input
+            graph.replace_input(node, node.input[i], node.inputs[i].input[0], i)
+
+        for i in out_indices:
+            graph.copy_dtype(dequant_node.input[0], node.output[i])
+            out_dq = self.make_q_or_dq(graph, "DequantizeLinear", node.output[i], dequant_node, node.name)
+            graph.insert_node_on_output(out_dq, node.output[i])
+
+        return True
+
+    def has_same_quantization_params(self, node1, node2):
+        if node1.get_attr_value("axis") != node2.get_attr_value("axis"):
+            return False
+        # Constant merging will ensure these are the same nodes if they are equal
+        return node1.input[1:] == node2.input[1:]
+
+    def make_q_or_dq(self, graph, op_type, inp, reference_node, name_scope):
+        """Makes a QuantizeLinear or DequantizeLinear with quantization params copied from the reference_node"""
+        axis = reference_node.get_attr_value("axis")
+        if axis is None:
+            attr = {}
+        else:
+            attr = {'axis': axis}
+        return graph.make_node(op_type, [inp] + reference_node.input[1:], attr=attr, op_name_scope=name_scope)
+
+
+def is_tensor_op(g, node):
+    """Detects ops that reshape/shuffle tensor elements without computing/changing them (Transpose, Gather, etc.)
+    Returns None or a tuple (inp_indices, out_indices) s.t. all corresponding outputs of the node depend only
+    on elements of the corresponding inputs of the node and all other inputs/outputs are unchanged.
+    WARNING: Transpose optimizer pushes transpose down so be careful when swapping to avoid infinite loop."""
+    if node.type in ["Identity", "Reshape", "Flatten", "Expand", "Transpose", "Squeeze", "Unsqueeze", "Slice"]:
+        return ([0], [0])
+    if node.type in ["Gather", "GatherND", "GatherElements"]:
+        # Output depends on data if indices is unchanged
+        return ([0], [0])
+    if node.type in ["Scatter", "ScatterND", "ScatterElements"]:
+        # Output depends on data and updates if indices is unchanged
+        return ([0, 2], [0])
+    if node.type == "Concat":
+        return (list(range(len(node.input))), [0])
+    if node.type == "Split":
+        return ([0], list(range(len(node.output))))
+    if node.type in ["Compress", "Tile", "ReverseSequence", "DepthToSpace"]:
+        return ([0], [0])
+    return None
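For reference, a sketch of applying just this pass to an existing tf2onnx Graph g (g is assumed to come from a prior conversion; the second argument is the optimizer-loop iteration index introduced in this commit):

from tf2onnx.optimizer.q_dq_optimizer import QDQOptimizer

opt = QDQOptimizer()
g = opt.optimize(g, 0)  # returns the optimized graph
print("graph changed:", opt.graph_been_opt)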

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 0 deletions
@@ -916,6 +916,13 @@ def _slice_handler(self, trans, node):
 
     def _quantize_handler(self, trans, node):
         # Used for QuantizeLinear and DequantizeLinear
+        if node.type == "DequantizeLinear":
+            # Only push through if we will be able to push through consumers too.
+            cons = self._g.find_output_consumers(node.output[0])
+            # If there is a false positive in the handler map, the q_dq and transpose optimizers might fight.
+            # Give up after 3 iterations. The q_dq optimizer should win so the dq hugs the op.
+            if not all(n.type in self._handler_map for n in cons) or self.opt_iteration >= 3:
+                return False
         if not self._switch_transpose_and_node(node, trans):
             return False
         if 'axis' in node.attr:
