Skip to content

Commit b2cb5e4

Browse files
committed
Change the part of the Add handler that applies when a Transpose follows a convolution.
Add test cases for the Add handler.
1 parent b23377c commit b2cb5e4

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

tests/test_optimizers.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,91 @@ def test_trans_with_sub_input_non_const(self):
445445
"non_const": np.random.randn(*non_const_shape).astype(np.float32)},
446446
model_proto, remaining_transpose_num=1)
447447

448+
def test_transpose_add_with_input_non_const(self):
    """Transpose -> Add -> inverse Transpose, with both Add operands fed as graph inputs.

    Builds a three-node graph where "X" is transposed with perm [0, 2, 3, 1],
    added to the graph input "A", and transposed back with perm [0, 3, 1, 2].
    The comparison is run expecting no Transpose node to remain.
    """
    x_shape = (1, 1, 3, 3)
    a_shape = (1, 3, 3, 1)

    fwd_transpose = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add = helper.make_node("Add", ["Y", "A"], ["Z"], name="add")
    inv_transpose = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph_inputs = [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
        helper.make_tensor_value_info("A", TensorProto.FLOAT, a_shape),
    ]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, x_shape)]
    graph = helper.make_graph(
        [fwd_transpose, add, inv_transpose],
        "transpose-add-test-input-non-const",
        graph_inputs,
        graph_outputs,
    )
    model_proto = helper.make_model(graph, producer_name="onnx-tests")

    # Random float32 feeds, drawn in the same order as the inputs are declared.
    feed = {
        "X": np.random.randn(*x_shape).astype(np.float32),
        "A": np.random.randn(*a_shape).astype(np.float32),
    }
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
466+
467+
def test_transpose_add_with_input_const(self):
    """Transpose -> Add -> inverse Transpose, where the second Add operand is a Constant.

    The constant "const_1" has shape (1, 3, 3, 1) and random float32 values.
    The comparison is run expecting no Transpose node to remain.
    """
    x_shape = (1, 1, 3, 3)
    const_shape = (1, 3, 3, 1)

    # make_tensor takes the values as a flat list alongside the target shape.
    flat_vals = np.random.randn(*const_shape).astype(np.float32).reshape(-1).tolist()
    const_tensor = helper.make_tensor("const_1", TensorProto.FLOAT, const_shape, flat_vals)
    const_node = helper.make_node("Constant", [], ["const_1"], value=const_tensor, name="const_1")

    fwd_transpose = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add = helper.make_node("Add", ["Y", "const_1"], ["Z"], name="add")
    inv_transpose = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape)]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, x_shape)]
    graph = helper.make_graph(
        [const_node, fwd_transpose, add, inv_transpose],
        "transpose-add-test-input-const",
        graph_inputs,
        graph_outputs,
    )
    model_proto = helper.make_model(graph, producer_name="onnx-tests")

    feed = {"X": np.random.randn(*x_shape).astype(np.float32)}
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
486+
487+
def test_transpose_add_with_conv_1(self):
    """Conv -> Transpose -> Add(Constant) -> inverse Transpose, constant of shape (1, 1, 1, 16).

    The constant has a single non-unit dimension (16), matching the number of
    filters in the weight input "W" of shape (16, 5, 3, 3).
    The comparison is run expecting no Transpose node to remain.
    """
    x_shape = (1, 5, 3, 3)
    w_shape = (16, 5, 3, 3)
    bias_shape = (1, 1, 1, 16)

    # make_tensor takes the values as a flat list alongside the target shape.
    flat_vals = np.random.randn(*bias_shape).astype(np.float32).reshape(-1).tolist()
    bias_tensor = helper.make_tensor("const_b", TensorProto.FLOAT, bias_shape, flat_vals)
    bias_node = helper.make_node("Constant", [], ["const_b"], value=bias_tensor, name="const_b")

    conv = helper.make_node("Conv", ["x", "W"], ["X"], name="conv", pads=[0, 0, 0, 0])
    fwd_transpose = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add = helper.make_node("Add", ["Y", "const_b"], ["Z"], name="add")
    inv_transpose = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph_inputs = [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, w_shape),
    ]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 16, 1, 1))]
    graph = helper.make_graph(
        [bias_node, conv, fwd_transpose, add, inv_transpose],
        "transpose-add-test-with-conv-1",
        graph_inputs,
        graph_outputs,
    )
    model_proto = helper.make_model(graph, producer_name="onnx-tests")

    # Random float32 feeds, drawn in the same order as the inputs are declared.
    feed = {
        "x": np.random.randn(*x_shape).astype(np.float32),
        "W": np.random.randn(*w_shape).astype(np.float32),
    }
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
509+
510+
def test_transpose_add_with_conv_2(self):
    """Conv -> Transpose -> Add(Constant) -> inverse Transpose, constant of shape (1, 3, 3, 1).

    Unlike the (1, 1, 1, 16) case, this constant has two non-unit dimensions.
    The comparison is run expecting no Transpose node to remain.
    """
    x_shape = (1, 1, 5, 5)
    w_shape = (1, 1, 3, 3)
    const_shape = (1, 3, 3, 1)

    # make_tensor takes the values as a flat list alongside the target shape.
    flat_vals = np.random.randn(*const_shape).astype(np.float32).reshape(-1).tolist()
    const_tensor = helper.make_tensor("const_b", TensorProto.FLOAT, const_shape, flat_vals)
    const_node = helper.make_node("Constant", [], ["const_b"], value=const_tensor, name="const_b")

    conv = helper.make_node("Conv", ["x", "W"], ["X"], name="conv", pads=[0, 0, 0, 0])
    fwd_transpose = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add = helper.make_node("Add", ["Y", "const_b"], ["Z"], name="add")
    inv_transpose = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph_inputs = [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, w_shape),
    ]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 1, 3, 3))]
    graph = helper.make_graph(
        [const_node, conv, fwd_transpose, add, inv_transpose],
        "transpose-add-test-with-conv-2",
        graph_inputs,
        graph_outputs,
    )
    model_proto = helper.make_model(graph, producer_name="onnx-tests")

    # Random float32 feeds, drawn in the same order as the inputs are declared.
    feed = {
        "x": np.random.randn(*x_shape).astype(np.float32),
        "W": np.random.randn(*w_shape).astype(np.float32),
    }
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
532+
448533
def test_trans_output_as_graph_outputs(self):
449534
"""
450535
If transpose's output is graph's output, don't optimize it.

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,16 +370,23 @@ def _add_handler(self, trans, node):
370370
if t_p.type in ("Conv", "ConvTranspose") and len(t_p.input) == 2:
371371
# if Conv or ConvTranspose's bias input is not set, then we set, otherwise, we don't set
372372
# todo: maybe we can add already set bias with the input??? try later
373+
374+
target_node = node.inputs[1]
375+
numpy_val = target_node.get_tensor_value(as_list=False)
376+
# Optional 1D bias to be added to the convolution, has size of M
377+
if len(numpy_val.shape) - numpy_val.shape.count(1) > 1:
378+
return self._handle_node_having_branches(node)
379+
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
380+
target_node.set_tensor_value(transposed_val)
381+
373382
conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
374383
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
375384
ops = self._g.get_nodes()
376385
trans.input[0] = utils.port_name(conv_node.name)
377386
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
378-
379387
self._g.remove_node(t_p.name)
380388
self._g.remove_node(node.name)
381389
return True
382-
return False
383390
return self._handle_node_having_branches(node)
384391

385392
def _transpose_handler(self, trans, node):

0 commit comments

Comments
 (0)