Skip to content

Commit a6fa737

Browse files
committed
Merge branch 'master' into gs/api
2 parents 3c0973e + 712518f commit a6fa737

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

tests/test_optimizers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,30 @@ def test_two_transposes_switch_with_mul(self):
712712
"u2": np.random.randn(1, 6, 8, 9).astype(np.float32)},
713713
model_proto, remaining_transpose_num=0)
714714

715+
def test_many_transposes_and_constant_switch_with_sum(self):
    """Sum over three identically-permuted Transposes plus a constant.

    Graph: v1..v3 = Transpose(u1..u3, perm=[0, 2, 3, 1]); x = Sum(v1, v2, v3, v4)
    with v4 a constant; y = Sum(x, v1); res = Transpose(y, perm=[0, 3, 1, 2]).
    The comparison is run with remaining_transpose_num=0.
    """
    input_shape = (1, 6, 8, 9)

    # The constant is created first so the order of random draws is unchanged.
    const_v4 = self._make_onnx_const(np.array(np.random.random((1, 8, 9, 6)), dtype=np.float32), "v4")

    # One Transpose per graph input, all sharing the same permutation.
    leading_transposes = [
        helper.make_node("Transpose", [src], [dst], perm=[0, 2, 3, 1], name=node_name)
        for src, dst, node_name in (
            ("u1", "v1", "trans_0"),
            ("u2", "v2", "trans_1"),
            ("u3", "v3", "trans_2"),
        )
    ]

    first_sum = helper.make_node("Sum", ["v1", "v2", "v3", "v4"], ["x"], name="sum_1")
    second_sum = helper.make_node("Sum", ["x", "v1"], ["y"], name="sum_2")
    trailing_transpose = helper.make_node("Transpose", ["y"], ["res"], perm=[0, 3, 1, 2], name="trans_4")

    graph_inputs = [
        helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape)
        for input_name in ("u1", "u2", "u3")
    ]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)]

    graph = helper.make_graph(
        [const_v4] + leading_transposes + [first_sum, second_sum, trailing_transpose],
        "test-transpose-mul",
        graph_inputs,
        graph_outputs,
    )
    model_proto = self.make_model(graph, producer_name="onnx-tests")

    # Feeds are drawn in u1, u2, u3 order, as before.
    feed = {
        input_name: np.random.randn(*input_shape).astype(np.float32)
        for input_name in ("u1", "u2", "u3")
    }
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
738+
715739
# Transpose Optimizer Tests End
716740

717741
# Identity Optimizer Tests Start

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def _initialize_handlers(self):
185185
"Relu": self._simple_through_handler,
186186
"Sigmoid": self._simple_through_handler,
187187
"Shape": self._shape_handler,
188+
"Sigmoid": self._simple_through_handler,
189+
"Sum": self._sum_handler,
188190
"Slice": self._slice_handler,
189191
"Split": self._split_handler,
190192
"Squeeze": self._squeeze_handler,
@@ -478,6 +480,50 @@ def _mul_handler(self, trans, node):
478480

479481
return False
480482

483+
def _sum_handler(self, trans, node):
    """Try to move a Transpose from the inputs of a Sum to its output.

    Rewrites sum(trans(x1), trans(x2), const, ...) into
    trans(sum(x1, x2, const', ...)), where const' is the constant value
    permuted with the inverse of ``perm``.

    Args:
        trans: the Transpose node feeding ``node`` that triggered this handler.
        node: the Sum node being examined.

    Returns:
        True if the graph was rewritten, False if the pattern did not match
        (in which case nothing is modified).
    """
    inputs = node.inputs
    perm = list(trans.get_attr('perm').ints)
    # A Transpose output always has as many dims as its perm, so the rank is
    # taken from perm instead of relying on a (possibly missing) inferred shape.
    rank = len(perm)
    # inverse permutation: applying it after perm restores the original layout
    untrans_idx = [perm.index(i) for i in range(rank)]

    # check if sum(trans(x1), trans(x2), const(x3), ...) can be switched
    for n in inputs:
        if n.type not in ["Transpose", "Const"]:
            return False
        if not self._nodes_has_single_consumer_node([n]):
            return False
        if n.is_const():
            # if graph is valid, op shapes should be valid
            # const is special case, in case of broadcasting
            # ensure rank matches; an unknown shape cannot be verified, so bail out
            n_shape = self._g.get_shape(n.output[0])
            if n_shape is None or len(n_shape) != rank:
                return False
        elif list(n.get_attr('perm').ints) != perm:
            return False

    # switch to trans(sum(x1, x2, x3, ...))
    ops = self._g.get_nodes()
    self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
    node.input = [n.output[0] if n.is_const() else n.input[0] for n in inputs]
    trans.input[0] = node.output[0]

    # adjust shape if present
    shape = self._g.get_shape(node.output[0])
    if shape:
        self._g.set_shape(node.output[0], [shape[i] for i in untrans_idx])

    # update constants, remove dangling transposes
    # The same node may feed the Sum more than once; handle each node only
    # once so a constant is not permuted twice and a transpose is not
    # removed twice.
    handled = set()
    for n in inputs:
        if n.name in handled:
            continue
        handled.add(n.name)
        if n.is_const():
            val = n.get_tensor_value(as_list=False)
            new_val = np.transpose(val, untrans_idx)
            n.set_tensor_value(new_val)
        elif n.name != trans.name:
            self._g.remove_node(n.name)
    return True
526+
481527
def _identity_handler(self, trans, node):
482528
if node.output[0] in node.graph.outputs:
483529
return False

0 commit comments

Comments
 (0)