Skip to content

Commit c420f3c

Browse files
Extend transpose optimizer for quantize and dequantize (#1314)
* Extend transpose optimizer for quantize and dequantize Signed-off-by: Tom Wildenhain <[email protected]> * Fix style Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 96e1a03 commit c420f3c

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

tests/test_optimizers.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,86 @@ def test_transpose_leaky_relu(self):
191191
self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
192192
model_proto, remaining_transpose_num=0)
193193

194+
@check_opset_min_version(10, "QuantizeLinear")
195+
def test_transpose_quantize(self):
196+
scale = numpy_helper.from_array(np.array(0.75, dtype=np.float32), name='scale')
197+
zero_point = numpy_helper.from_array(np.array(3, dtype=np.uint8), name='zero_point')
198+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
199+
node2 = helper.make_node("QuantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="quantize")
200+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")
201+
202+
graph = helper.make_graph(
203+
[node1, node2, node3],
204+
"quantize-test",
205+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
206+
[helper.make_tensor_value_info("Z1", TensorProto.UINT8, (2, 3, 4, 5))],
207+
[scale, zero_point]
208+
)
209+
210+
model_proto = self.make_model(graph, producer_name="onnx-tests")
211+
self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
212+
model_proto, remaining_transpose_num=0)
213+
214+
@check_opset_min_version(13, "QuantizeLinear with axis")
215+
def test_transpose_quantize_with_axis(self):
216+
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3, 0.42], dtype=np.float32), name='scale')
217+
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8, 10], dtype=np.uint8), name='zero_point')
218+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
219+
node2 = helper.make_node("QuantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="quantize", axis=2)
220+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")
221+
222+
graph = helper.make_graph(
223+
[node1, node2, node3],
224+
"quantize-test",
225+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
226+
[helper.make_tensor_value_info("Z1", TensorProto.UINT8, (2, 3, 4, 5))],
227+
[scale, zero_point]
228+
)
229+
230+
model_proto = self.make_model(graph, producer_name="onnx-tests")
231+
self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
232+
model_proto, remaining_transpose_num=0)
233+
234+
@check_opset_min_version(10, "DequantizeLinear")
235+
def test_transpose_dequantize(self):
236+
scale = numpy_helper.from_array(np.array(0.75, dtype=np.float32), name='scale')
237+
zero_point = numpy_helper.from_array(np.array(3, dtype=np.uint8), name='zero_point')
238+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
239+
node2 = helper.make_node("DequantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="dequantize")
240+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")
241+
242+
graph = helper.make_graph(
243+
[node1, node2, node3],
244+
"dequantize-test",
245+
[helper.make_tensor_value_info("X", TensorProto.UINT8, (2, 3, 4, 5))],
246+
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, (2, 3, 4, 5))],
247+
[scale, zero_point]
248+
)
249+
250+
model_proto = self.make_model(graph, producer_name="onnx-tests")
251+
self.run_transpose_compare(["Z1"], {"X": np.random.randint(0, 100, (2, 3, 4, 5), np.uint8)},
252+
model_proto, remaining_transpose_num=0)
253+
254+
@check_opset_min_version(13, "DequantizeLinear with axis")
255+
def test_transpose_dequantize_with_axis(self):
256+
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3, 0.42], dtype=np.float32), name='scale')
257+
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8, 10], dtype=np.uint8), name='zero_point')
258+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
259+
node2 = helper.make_node("DequantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="dequantize", axis=2)
260+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")
261+
262+
graph = helper.make_graph(
263+
[node1, node2, node3],
264+
"dequantize-test",
265+
[helper.make_tensor_value_info("X", TensorProto.UINT8, (2, 3, 4, 5))],
266+
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, (2, 3, 4, 5))],
267+
[scale, zero_point]
268+
)
269+
270+
model_proto = self.make_model(graph, producer_name="onnx-tests")
271+
self.run_transpose_compare(["Z1"], {"X": np.random.randint(0, 100, (2, 3, 4, 5), np.uint8)},
272+
model_proto, remaining_transpose_num=0)
273+
194274
@check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'")
195275
def test_transpose_slice(self):
196276
starts = np.array([0, 0, 0, 0], dtype=np.int64)

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ def _initialize_handlers(self):
202202
"Sub": self._sub_handler,
203203
"Tanh": self._simple_through_handler,
204204
"Transpose": self._transpose_handler,
205+
"DequantizeLinear": self._quantize_handler,
206+
"QuantizeLinear": self._quantize_handler,
205207
}
206208

207209
def _handle_node_having_branches(self, node):
@@ -698,6 +700,17 @@ def _slice_handler(self, trans, node):
698700
return self._switch_transpose_and_node(node, trans)
699701
return False
700702

703+
def _quantize_handler(self, trans, node):
704+
# Used for QuantizeLinear and DequantizeLinear
705+
if not self._switch_transpose_and_node(node, trans):
706+
return False
707+
if 'axis' in node.attr:
708+
perm = trans.get_attr_value("perm")
709+
axis = node.get_attr_value("axis")
710+
new_axis = perm[axis]
711+
node.set_attr("axis", new_axis)
712+
return True
713+
701714
def _simple_through_handler(self, trans, node):
702715
return self._switch_transpose_and_node(node, trans)
703716

0 commit comments

Comments
 (0)