Skip to content

Commit e8842ba

Browse files
Use simple_handler for Reciprocal,Sqrt in TransposeOptimizer (#1340)
Signed-off-by: Mateusz Tabaka <[email protected]>
1 parent c24dc6f commit e8842ba

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

tests/test_optimizers.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,26 @@ def test_transpose_with_identity(self, input_shape, output_shape, perm):
515515
self.run_transpose_compare(["Z"], {"X": np.random.randn(*input_shape).astype(np.float32)},
516516
model_proto, remaining_transpose_num=1)
517517

518+
@parameterized.expand([
519+
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
520+
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
521+
])
522+
def test_transpose_sqrt(self, shape, perm_input, perm_output):
523+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans1")
524+
node1 = helper.make_node("Sqrt", ["Y"], ["Z"], name="sqrt")
525+
node2 = helper.make_node("Transpose", ["Z"], ["OUT"], perm=perm_output, name="trans2")
526+
527+
graph = helper.make_graph(
528+
[node0, node1, node2],
529+
"transpose-sqrt-test",
530+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
531+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, shape)],
532+
)
533+
534+
model_proto = self.make_model(graph, producer_name="onnx-tests")
535+
self.run_transpose_compare(["OUT"], {"X": np.random.randn(*shape).astype(np.float32)},
536+
model_proto, remaining_transpose_num=0)
537+
518538
@parameterized.expand([
519539
((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
520540
((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
@@ -1000,6 +1020,27 @@ def test_transpose_pad11_non_const_pads(self, input_shape, output_shape, pads, p
10001020
},
10011021
model_proto, remaining_transpose_num=0)
10021022

1023+
@parameterized.expand([
1024+
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
1025+
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
1026+
])
1027+
def test_transpose_reciprocal(self, shape, perm_input, perm_output):
1028+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans1")
1029+
node1 = helper.make_node("Reciprocal", ["Y"], ["Z"], name="reciprocal")
1030+
node2 = helper.make_node("Transpose", ["Z"], ["OUT"], perm=perm_output, name="trans2")
1031+
1032+
graph = helper.make_graph(
1033+
[node0, node1, node2],
1034+
"transpose-reciprocal-test",
1035+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
1036+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, shape)],
1037+
)
1038+
1039+
model_proto = self.make_model(graph, producer_name="onnx-tests")
1040+
self.run_transpose_compare(["OUT"], {"X": np.random.randn(*shape).astype(np.float32)},
1041+
model_proto, remaining_transpose_num=0)
1042+
1043+
10031044
@parameterized.expand([
10041045
((1, 3, 4, 5), (1, 3, 1, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
10051046
((1, 3, 4, 5, 6), (1, 3, 1, 1, 1), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def _initialize_handlers(self):
191191
"Min": self._maxmin_handler,
192192
"Mul": self._mul_handler,
193193
"Pad": self._pad_handler,
194+
"Reciprocal": self._simple_through_handler,
194195
"ReduceMean": self._reducemean_handler,
195196
"Relu": self._simple_through_handler,
196197
"Shape": self._shape_handler,
@@ -199,6 +200,7 @@ def _initialize_handlers(self):
199200
"Slice": self._slice_handler,
200201
"Split": self._split_handler,
201202
"Softplus": self._simple_through_handler,
203+
"Sqrt": self._simple_through_handler,
202204
"Squeeze": self._squeeze_handler,
203205
"Sub": self._sub_handler,
204206
"Tanh": self._simple_through_handler,

0 commit comments

Comments
 (0)