Skip to content

Commit 712518f

Browse files
authored
Merge pull request #884 from jignparm/jignparm/transpose_sum_optimizer
Adds Sum(Transpose(x1), Transpose(x2),...) optimizer.
2 parents 4767711 + c765ef9 commit 712518f

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

tests/test_optimizers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,30 @@ def test_two_transposes_switch_with_mul(self):
712712
"u2": np.random.randn(1, 6, 8, 9).astype(np.float32)},
713713
model_proto, remaining_transpose_num=0)
714714

715+
def test_many_transposes_and_constant_switch_with_sum(self):
716+
constnode = self._make_onnx_const(np.array(np.random.random((1, 8, 9, 6)), dtype=np.float32), "v4")
717+
node0 = helper.make_node("Transpose", ["u1"], ["v1"], perm=[0, 2, 3, 1], name="trans_0")
718+
node1 = helper.make_node("Transpose", ["u2"], ["v2"], perm=[0, 2, 3, 1], name="trans_1")
719+
node11 = helper.make_node("Transpose", ["u3"], ["v3"], perm=[0, 2, 3, 1], name="trans_2")
720+
721+
node2 = helper.make_node("Sum", ["v1", "v2", "v3", "v4"], ["x"], name="sum_1")
722+
node3 = helper.make_node("Sum", ["x", "v1"], ["y"], name="sum_2")
723+
node4 = helper.make_node("Transpose", ["y"], ["res"], perm=[0, 3, 1, 2], name="trans_4")
724+
725+
graph = helper.make_graph(
726+
[constnode, node0, node1, node11, node2, node3, node4],
727+
"test-transpose-mul",
728+
[helper.make_tensor_value_info("u1", TensorProto.FLOAT, (1, 6, 8, 9)),
729+
helper.make_tensor_value_info("u2", TensorProto.FLOAT, (1, 6, 8, 9)),
730+
helper.make_tensor_value_info("u3", TensorProto.FLOAT, (1, 6, 8, 9))],
731+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 6, 8, 9))],
732+
)
733+
model_proto = self.make_model(graph, producer_name="onnx-tests")
734+
self.run_transpose_compare(["res"], {"u1": np.random.randn(1, 6, 8, 9).astype(np.float32),
735+
"u2": np.random.randn(1, 6, 8, 9).astype(np.float32),
736+
"u3": np.random.randn(1, 6, 8, 9).astype(np.float32)},
737+
model_proto, remaining_transpose_num=0)
738+
715739
# Tranpose Optimizer Tests End
716740

717741
# Identity Optimizer Tests Start

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def _initialize_handlers(self):
184184
"ReduceMean": self._reducemean_handler,
185185
"Relu": self._simple_through_handler,
186186
"Shape": self._shape_handler,
187+
"Sigmoid": self._simple_through_handler,
188+
"Sum": self._sum_handler,
187189
"Slice": self._slice_handler,
188190
"Split": self._split_handler,
189191
"Squeeze": self._squeeze_handler,
@@ -477,6 +479,50 @@ def _mul_handler(self, trans, node):
477479

478480
return False
479481

482+
def _sum_handler(self, trans, node):
483+
inputs = node.inputs
484+
trans_shape = self._g.get_shape(trans.output[0])
485+
perm = list(trans.get_attr('perm').ints)
486+
untrans_idx = [perm.index(i) for i in range(len(perm))]
487+
488+
# check if sum(trans(x1), trans(x2), const(x3), ...) can be switched
489+
for n in inputs:
490+
if n.type not in ["Transpose", "Const"]:
491+
return False
492+
if not self._nodes_has_single_consumer_node([n]):
493+
return False
494+
if n.is_const():
495+
# if graph is valid, op shapes should be valid
496+
# const is special case, in case of broadcasting
497+
# ensure rank matches
498+
n_shape = self._g.get_shape(n.output[0])
499+
if len(n_shape) != len(trans_shape):
500+
return False
501+
else:
502+
if list(n.get_attr('perm').ints) != perm:
503+
return False
504+
505+
# switch to trans(sum(x1, x2, x3, ...))
506+
ops = self._g.get_nodes()
507+
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
508+
node.input = [n.output[0] if n.is_const() else n.input[0] for n in inputs]
509+
trans.input[0] = node.output[0]
510+
511+
# adjust shape if present
512+
shape = self._g.get_shape(node.output[0])
513+
if shape:
514+
self._g.set_shape(node.output[0], [shape[i] for i in untrans_idx])
515+
516+
# update constants, remove dangling transposes
517+
for n in inputs:
518+
if n.is_const():
519+
val = n.get_tensor_value(as_list=False)
520+
new_val = np.transpose(val, untrans_idx)
521+
n.set_tensor_value(new_val)
522+
elif n.name != trans.name:
523+
self._g.remove_node(n.name)
524+
return True
525+
480526
def _identity_handler(self, trans, node):
481527
if node.output[0] in node.graph.outputs:
482528
return False

0 commit comments

Comments
 (0)