Skip to content

Commit e58a087

Browse files
Bug fixes for opset 13 support
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 7ef484c commit e58a087

File tree

8 files changed

+76
-38
lines changed

8 files changed

+76
-38
lines changed

tests/test_optimizers.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
from onnx import helper, TensorProto, OperatorSetIdProto
1313
from backend_test_base import Tf2OnnxBackendTestBase
14-
from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version
14+
from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
1515
from tf2onnx import utils, constants
1616
from tf2onnx.graph import GraphUtil
1717

@@ -471,12 +471,23 @@ def _define_loop_graph(external_inputs):
471471
node1 = helper.make_node("Gather", [external_inputs[0], "loop_iter_num"], ["Y0"])
472472
node2 = helper.make_node("Transpose", ["Y0"], ["Z0"], perm=[0, 2, 3, 1])
473473
# graph output
474-
node3 = helper.make_node("Squeeze", ["Z0"], ["scan_output"], axes=[0])
474+
if get_test_config().opset <= 12:
475+
node3 = helper.make_node("Squeeze", ["Z0"], ["scan_output"], axes=[0])
476+
const_node = None
477+
else:
478+
const_tensor = helper.make_tensor(name='const', data_type=TensorProto.INT64, dims=[1],
479+
vals=np.array([0], dtype=np.int64))
480+
const_node = helper.make_node("Constant", [], ["axes_const"], value=const_tensor, name="const")
481+
node3 = helper.make_node("Squeeze", ["Z0", "axes_const"], ["scan_output"])
475482
node4 = helper.make_node("Identity", ["loop_condition"], ["loop_cond_output"])
476483
node5 = helper.make_node("Identity", ["loop_condition"], ["loop_carried_output"])
477484

485+
nodes = [node1, node2, node3, node4, node5]
486+
if const_node is not None:
487+
nodes.append(const_node)
488+
478489
graph = helper.make_graph(
479-
[node1, node2, node3, node4, node5],
490+
nodes,
480491
"loop_subgraph",
481492
[helper.make_tensor_value_info("loop_iter_num", TensorProto.INT64, (1,)), # iteration_num
482493
helper.make_tensor_value_info("loop_condition", TensorProto.BOOL, ()), # condition
@@ -973,9 +984,9 @@ def test_duplicated_duplicated_input(self):
973984

974985
def test_duplicated_duplicated_attributes(self):
975986
# same attr or not
976-
node0 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value0"], axes=[0], keepdims=0)
977-
node1 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value1"], axes=[0], keepdims=0)
978-
node2 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value2"], axes=[1], keepdims=0)
987+
node0 = helper.make_node('ReduceMin', inputs=["X"], outputs=["value0"], axes=[0], keepdims=0)
988+
node1 = helper.make_node('ReduceMin', inputs=["X"], outputs=["value1"], axes=[0], keepdims=0)
989+
node2 = helper.make_node('ReduceMin', inputs=["X"], outputs=["value2"], axes=[1], keepdims=0)
979990
node3 = helper.make_node('Add', inputs=["value0", "value1"], outputs=["value3"])
980991
node4 = helper.make_node("Mul", ["value2", "value3"], ["OUT"])
981992

@@ -988,7 +999,7 @@ def test_duplicated_duplicated_attributes(self):
988999

9891000
model_proto = self.make_model(graph, producer_name="onnx-tests")
9901001
self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
991-
op_type="ReduceSum", remaining_op_num=2)
1002+
op_type="ReduceMin", remaining_op_num=2)
9921003

9931004
def _check_initializer_num(self, graph_proto, num):
9941005
return num == len(graph_proto.initializer)

tf2onnx/graph_builder.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None, return_node=Fa
5555
inputs = [data, starts, ends, axes, steps]
5656

5757
# pro-process inputs and attr
58-
if kwargs:
59-
logger.warning("kwargs contains un-used key")
58+
utils.make_sure(not kwargs, "kwargs contains un-used key")
6059

6160
new_attr = {}
6261
for key, val in attr.items():
@@ -107,8 +106,7 @@ def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
107106
attr = {"keepdims": keepdims, "noop_with_empty_axes": noop_with_empty_axes}
108107
inputs = [data, axes]
109108

110-
if kwargs:
111-
logger.warning("kwargs contains un-used key")
109+
utils.make_sure(not kwargs, "kwargs contains un-used key")
112110

113111
new_attr = {}
114112
for key, val in attr.items():
@@ -137,8 +135,7 @@ def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None, return_node=
137135
attr = {}
138136
inputs = [data, axes]
139137

140-
if kwargs:
141-
logger.warning("kwargs contains un-used key")
138+
utils.make_sure(not kwargs, "kwargs contains un-used key")
142139

143140
new_attr = {}
144141
for key, val in attr.items():
@@ -178,8 +175,7 @@ def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None, return_nod
178175
attr = {}
179176
inputs = [data, axes]
180177

181-
if kwargs:
182-
logger.warning("kwargs contains un-used key")
178+
utils.make_sure(not kwargs, "kwargs contains un-used key")
183179

184180
new_attr = {}
185181
for key, val in attr.items():

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,10 @@ def version_13(cls, ctx, node, **kwargs):
298298

299299
g = GraphBuilder(ctx)
300300

301-
usq_node = g.make_unsqueeze({"axes": [0], 'name': node.child_name(), 'data': node.input[1]}, return_node=True)
301+
usq_node = g.make_unsqueeze({"axes": [0], 'data': node.input[1]}, name=node.child_name(), return_node=True)
302302
ctx.insert_node_on_output(usq_node)
303303

304-
sq_node = g.make_squeeze({"axes": [0], 'name': node.child_name(), 'data': node.output[0]}, return_node=True)
304+
sq_node = g.make_squeeze({"axes": [0], 'data': node.output[0]}, name=node.child_name(), return_node=True)
305305
ctx.insert_node_on_output(sq_node)
306306

307307

tf2onnx/onnx_opset/tensor.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -805,8 +805,10 @@ def version_1(cls, ctx, node, **kwargs):
805805
# insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)
806806
# ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
807807
name = utils.make_name(node.name)
808-
squeeze_node = ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
809-
squeeze_node.set_attr("axes", needs_squeeze)
808+
squeeze_node = GraphBuilder(ctx).make_squeeze(
809+
{"axes": needs_squeeze, 'data': node.output[0]}, name=name, return_node=True)
810+
ctx.insert_node_on_output(squeeze_node)
811+
810812
nodes.append(squeeze_node)
811813
input_dtype = ctx.get_dtype(node.output[0])
812814
ctx.set_dtype(squeeze_node.output[0], input_dtype)
@@ -1023,8 +1025,9 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
10231025
node = GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=out_dtypes, shapes=out_shapes)
10241026
node = ctx.get_node_by_output(node)
10251027
if needs_squeeze:
1026-
squeeze_node = ctx.insert_new_node_on_output("Squeeze", node.output[0], node.child_name())
1027-
squeeze_node.set_attr("axes", needs_squeeze)
1028+
squeeze_node = GraphBuilder(ctx).make_squeeze(
1029+
{"axes": needs_squeeze, "data": node.output[0]}, name=node.child_name(), return_node=True)
1030+
ctx.insert_node_on_output(squeeze_node, node.output[0])
10281031
input_dtype = ctx.get_dtype(node.output[0])
10291032
ctx.set_dtype(squeeze_node.output[0], input_dtype)
10301033
ctx.copy_shape(node.output[0], squeeze_node.output[0])
@@ -1348,7 +1351,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
13481351
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
13491352
if len(input_shape) == 3:
13501353
# insert automatically an Unsqueeze op if the input is 3d
1351-
unsqz1 = ctx.make_node("Unsqueeze", input_tensor.output, {"axes": [3]})
1354+
unsqz1 = GraphBuilder(ctx).make_unsqueeze(
1355+
{"axes": [3], "data": input_tensor.output[0]}, return_node=True)
13521356
trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
13531357
else:
13541358
trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
@@ -1377,22 +1381,21 @@ def any_version(cls, opset, ctx, node, **kwargs):
13771381
kwargs = {**inputs_map}
13781382
ctx.remove_node(node.name)
13791383
slice1 = GraphBuilder(ctx).make_slice(kwargs)
1380-
ctx.make_node("Squeeze", [slice1], {"axes": [3]},
1381-
outputs=node.output, name=node.name, dtypes=dtypes, shapes=shapes)
1384+
GraphBuilder(ctx).make_squeeze(
1385+
{"axes": [3], "data": slice1, "outputs": node.output}, name=node.name, dtypes=dtypes, shapes=shapes)
13821386
else:
13831387
kwargs = {**inputs_map, "outputs": node.output}
13841388
ctx.remove_node(node.name)
13851389
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
13861390
else:
13871391
def mknode(optype, inputs, attrs=None):
13881392
nodename = utils.make_name(node.name + '_' + optype.lower())
1389-
if opset < 13 or optype != 'Squeeze':
1390-
return ctx.make_node(optype, inputs, attrs, name=nodename)
1391-
inputs.append(attrs['axes'])
1392-
attrs = attrs.copy()
1393-
attrs.pop('axes')
1393+
if opset >= 13 and optype == 'Squeeze':
1394+
return GraphBuilder(ctx).make_squeeze(
1395+
{"axes": attrs['axes'], "data": inputs[0]}, name=nodename, return_node=True)
13941396
return ctx.make_node(optype, inputs, attrs, name=nodename)
13951397

1398+
13961399
# support non 3D/4D tensors and dynamic crop vals
13971400
# dynamic slice starts at opset 10
13981401
utils.make_sure(ctx.opset >= 11, 'non-4D tensor or non-const crops require opset 11')

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,20 @@ def _optimize_squeeze_unsqueeze(g, node, consumer_nodes):
161161
if node2.type != 'Unsqueeze':
162162
return []
163163

164-
axis1 = node.get_attr('axes').ints
165-
axis2 = node2.get_attr('axes').ints
164+
axes_match = False
165+
if g.opset <= 12 and node.get_attr('axes').ints == node2.get_attr('axes').ints:
166+
axes_match = True
167+
168+
# In opset 13, axes is an input. Optional for squeeze op.
169+
if g.opset >= 13 and len(node.input) == 2:
170+
if node.input[1] == node2.input[1]:
171+
axes_match = True
172+
elif node.inputs[1].is_const() and node2.inputs[1].is_const() and \
173+
node.inputs[1].get_tensor_value(as_list=True) == node2.inputs[1].get_tensor_value(as_list=True):
174+
axes_match = True
166175

167176
# if squeeze followed by unsqueeze is on diff axes, skip
168-
if axis1 != axis2:
177+
if not axes_match:
169178
return []
170179

171180
# if unsqueeze output is graph output, skip

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,13 @@ def _fold_unsqueeze(node, graph):
130130
numpy expand_dims only supports to unsqueeze one dim one time, so reshape is used to simplify the logic
131131
"""
132132
const_val = node.inputs[0].get_tensor_value(as_list=False)
133-
axes = list(node.get_attr("axes").ints)
134-
utils.make_sure(all(axis >= 0 for axis in axes), "onnx spec says it only supports positive axis")
133+
if graph.opset >= 13:
134+
axes = node.inputs[1].get_tensor_value(as_list=True)
135+
else:
136+
axes = list(node.get_attr("axes").ints)
135137
shape_in = const_val.shape
136138
dims_out = len(shape_in) + len(axes)
139+
axes = [i if i >= 0 else i + dims_out for i in axes]
137140
# calculate the shape of output accroding to onnx Unsqueeze's spec
138141
# https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze
139142
shape_in = iter(shape_in)

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,18 +587,30 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
587587
if not self._nodes_has_single_consumer_node([trans]):
588588
return False
589589

590+
axes = None
591+
# in opset 13, axes is an input not attr
590592
if node.get_attr("axes"):
593+
axes = node.get_attr("axes").ints
594+
if len(node.input) > 1 and node.inputs[1].is_const():
595+
axes = node.inputs[1].get_tensor_value(as_list=True)
596+
597+
if axes is not None:
591598
# switch tran and squeeze
592599
# 1 switch
593600
self._g.replace_all_inputs(node.output[0], trans.output[0]) # ops=self._g.get_nodes()
594601
self._g.replace_input(node, node.input[0], trans.input[0], 0)
595602
self._g.replace_input(trans, trans.input[0], node.output[0], 0)
596603
# 2 correct attr of nodes
597-
squeeze_axes = sorted(list(node.get_attr("axes").ints))
604+
squeeze_axes = sorted(axes)
598605
trans_perm = list(trans.get_attr("perm").ints)
599606
new_perm, new_squeeze_axes = _calculate_new_attr(ori_perm=trans_perm, ori_squeeze_axes=squeeze_axes)
600607
trans.set_attr("perm", new_perm)
601-
node.set_attr("axes", new_squeeze_axes)
608+
if self._g.opset <= 12:
609+
node.set_attr("axes", new_squeeze_axes)
610+
else:
611+
new_axes_np = np.array(new_squeeze_axes, dtype=np.int64)
612+
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
613+
self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
602614
# 3 set shape
603615
squeeze_shape = self._g.get_shape(node.output[0])
604616
self._g.set_shape(trans.output[0], squeeze_shape)

tf2onnx/rewriter/eye_rewriter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from onnx import onnx_pb
9+
from tf2onnx.graph_builder import GraphBuilder
910
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
1011

1112
# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
@@ -134,8 +135,11 @@ def rewrite_eye(g, ops):
134135
g.remove_node(old_output.name)
135136

136137
# onnx op "EyeLike" need a 2D tensor, so generate it
137-
num_rows = g.make_node("Unsqueeze", num_rows.output, attr={"axes": [0]})
138-
num_columns = g.make_node("Unsqueeze", num_columns.output, attr={"axes": [0]})
138+
139+
num_rows = GraphBuilder(g).make_unsqueeze(
140+
{"axes": [0], "data": num_rows.output[0]}, return_node=True)
141+
num_columns = GraphBuilder(g).make_unsqueeze(
142+
{"axes": [0], "data": num_columns.output[0]}, return_node=True)
139143
matrix_shape = g.make_node("Concat", [num_rows.output[0], num_columns.output[0]], attr={"axis": 0})
140144
# cast nodes added for "ConstantOfShape" in ONNX only accepts int64 data.
141145
matrix_shape_int64 = g.make_node("Cast", matrix_shape.output, attr={"to": onnx_pb.TensorProto.INT64})

0 commit comments

Comments
 (0)