Skip to content

Commit f545163

Browse files
Handle Abs in TransposeOptimizer (#1699)
Signed-off-by: Mateusz Tabaka <[email protected]>
1 parent 01ec092 commit f545163

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

tests/test_optimizers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,27 @@ def test_transpose_with_concat(self, input_shape, perm, inner_perm):
116116
}
117117
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
118118

119+
@parameterized.expand([
120+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
121+
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
122+
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
123+
])
124+
def test_transpose_abs(self, shape, perm_input, perm_output):
125+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans1")
126+
node1 = helper.make_node("Abs", ["Y"], ["Z"], name="abs")
127+
node2 = helper.make_node("Transpose", ["Z"], ["OUT"], perm=perm_output, name="trans2")
128+
129+
graph = helper.make_graph(
130+
[node0, node1, node2],
131+
"transpose-abs-test",
132+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
133+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, shape)],
134+
)
135+
136+
model_proto = self.make_model(graph, producer_name="onnx-tests")
137+
self.run_transpose_compare(["OUT"], {"X": np.random.randn(*shape).astype(np.float32)},
138+
model_proto, remaining_transpose_num=0)
139+
119140
@parameterized.expand([
120141
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
121142
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def _optimize_at_current_graph_level(self, graph):
193193

194194
def _initialize_handlers(self):
195195
self._handler_map = {
196+
"Abs": self._simple_through_handler,
196197
"Add": self._add_handler,
197198
"ArgMax": self._arg_min_max_handler,
198199
"ArgMin": self._arg_min_max_handler,

0 commit comments

Comments
 (0)