Skip to content

Commit 41fad17

Browse files
authored
Merge pull request #619 from zhijxu-MS/add_transpose_sub
add trans opt for sub
2 parents 244e90a + 75eea46 commit 41fad17

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

tests/test_optimizers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,32 @@ def _make_loop(external_inputs, outputs):
368368
self.run_transpose_compare(["Y"], {"array": np.random.randn(10, 3, 4, 5).astype(np.float32)},
369369
model_proto, remaining_transpose_num=0)
370370

371+
def test_trans_with_sub(self):
372+
io_shape = [2, 3, 4, 5]
373+
const_shapes = [[2, 4, 5, 3], [4, 5, 3], [5, 3], [3]]
374+
for trans_is_first_input in [True, False]:
375+
for const_shape in const_shapes:
376+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_a")
377+
const_tensor = helper.make_tensor(name='const', data_type=TensorProto.FLOAT, dims=const_shape,
378+
vals=np.random.randn(*const_shape).flatten().astype(np.float32))
379+
node2 = helper.make_node("Constant", [], ["const"], value=const_tensor, name="const")
380+
if trans_is_first_input:
381+
node3 = helper.make_node("Sub", ["Y", "const"], ["Z"], name="sub")
382+
else:
383+
node3 = helper.make_node("Sub", ["const", "Y"], ["Z"], name="sub")
384+
385+
node4 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_b")
386+
graph = helper.make_graph(
387+
[node1, node2, node3, node4],
388+
"test_trans_with_sub",
389+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, io_shape)],
390+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, io_shape)],
391+
)
392+
393+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
394+
self.run_transpose_compare(["res"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
395+
model_proto, remaining_transpose_num=0)
396+
371397
def test_trans_output_as_graph_outputs(self):
372398
"""
373399
If transpose's output is graph's output, don't optimize it.

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def _initialize_handlers(self):
186186
"Slice": self._slice_handler,
187187
"Split": self._split_handler,
188188
"Squeeze": self._squeeze_handler,
189+
"Sub": self._sub_handler,
189190
"Tanh": self._simple_through_handler,
190191
"Transpose": self._transpose_handler,
191192
}
@@ -522,6 +523,9 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
522523
return True
523524
return False
524525

526+
def _sub_handler(self, trans, node):
527+
return self._handle_node_having_branches(node)
528+
525529
def _pad_handler(self, trans, node):
526530
# [N-start, H-start, W-start, C-start, N-end, H-end, W-end, C-end]
527531
pads = node.get_attr('pads').ints # [x1_begin, x2_begin...x1_end, x2_end,...]

0 commit comments

Comments
 (0)