Skip to content

Commit 94ad674

Browse files
Remove double reshape (#1495)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 5330854 commit 94ad674

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

tests/test_backend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,17 @@ def func(x):
13411341
return tf.identity(x_, name=_TFOUTPUT)
13421342
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_shape=True)
13431343

1344+
def test_reshape_reshape(self):
1345+
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))
1346+
def func(x):
1347+
shape = tf.constant([1, 4])
1348+
shape_2 = tf.constant([4, 1])
1349+
x_ = tf.reshape(x, shape)
1350+
x_ = tf.reshape(x_, shape_2)
1351+
return tf.identity(x_, name=_TFOUTPUT)
1352+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val},
1353+
graph_validator=lambda g: check_op_count(g, "Reshape", 1, disabled=False))
1354+
13441355
@check_opset_min_version(6, "cast")
13451356
def test_reshape_int(self):
13461357
x_val = np.array([1, 2, 3, 4], dtype=np.int32).reshape((2, 2))

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818

1919
def _register_func(op_type):
20+
if not isinstance(op_type, tuple):
21+
op_type = (op_type,)
2022
def _internal_fun(func):
2123
_func_map[op_type] = func
2224
return func
@@ -252,3 +254,18 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
252254
g.set_shape(node2_output[0], node2_shape)
253255
g.set_dtype(node2_output[0], node2_dtype)
254256
return []
257+
258+
@staticmethod
259+
@_register_func('Reshape')
260+
def _optimize_reshape_reshape(g, node, consumer_nodes):
261+
"""remove sequential reshape nodes"""
262+
if node.type != 'Reshape' or len(consumer_nodes) != 1:
263+
return []
264+
265+
node2 = consumer_nodes[0]
266+
if node2.type != 'Reshape':
267+
return []
268+
269+
g.replace_inputs(node2, [node.input[0], node2.input[1]])
270+
g.remove_node(node.name)
271+
return []

0 commit comments

Comments
 (0)