Skip to content

Commit e36dac2

Browse files
Merge pull request #1241 from onnx/tom/opset13updates2
More bug fixes for opset 13 support
2 parents ac2f675 + 60b4ddd commit e36dac2

File tree

4 files changed

+25
-33
lines changed

4 files changed

+25
-33
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -286,24 +286,17 @@ class TensorListGetItem:
286286
def version_7(cls, ctx, node, **kwargs):
287287
ctx.ta_reads.append(node.input[0])
288288
node.type = "Gather"
289-
ctx.replace_inputs(node, [node.input[0], node.input[1]])
290-
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], name=node.child_name(), axes=[0])
291-
ctx.insert_new_node_on_output("Squeeze", node.output[0], name=node.child_name(), axes=[0])
292-
293-
@classmethod
294-
def version_13(cls, ctx, node, **kwargs):
295-
ctx.ta_reads.append(node.input[0])
296-
node.type = "Gather"
297-
ctx.replace_inputs(node, [node.input[0], node.input[1]])
298-
299289
g = GraphBuilder(ctx)
300290

301291
usq_node = g.make_unsqueeze({"axes": [0], 'data': node.input[1]}, name=node.child_name(), return_node=True)
302-
ctx.insert_node_on_output(usq_node)
303-
292+
ctx.replace_inputs(node, [node.input[0], usq_node.output[0]])
304293
sq_node = g.make_squeeze({"axes": [0], 'data': node.output[0]}, name=node.child_name(), return_node=True)
305294
ctx.insert_node_on_output(sq_node)
306295

296+
@classmethod
297+
def version_13(cls, ctx, node, **kwargs):
298+
cls.version_7(ctx, node, **kwargs)
299+
307300

308301
@tf_op(["TensorListLength"])
309302
class TensorListLength:

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def any_version(cls, opset, ctx, node, **kwargs):
353353
if len(input_shape) == spatial + 1:
354354
gb = GraphBuilder(ctx)
355355
usq_node = gb.make_unsqueeze({"axes": [0], 'data': node.input[0]}, return_node=True)
356-
ctx.insert_node_on_output(usq_node, node.input[0])
356+
ctx.replace_inputs(node, [usq_node.output[0]] + node.input[1:])
357357

358358
# Set padding.
359359
add_padding(

tf2onnx/onnx_opset/tensor.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,8 +1256,8 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
12561256
depth = GraphBuilder(ctx).make_unsqueeze({'data': node.input[1], 'axes': [0]})
12571257
on_value = node.input[2]
12581258
off_value = node.input[3]
1259-
on_value = ctx.make_node("Unsqueeze", [on_value], attr={"axes": [0]}).output[0]
1260-
off_value = ctx.make_node("Unsqueeze", [off_value], attr={"axes": [0]}).output[0]
1259+
on_value = GraphBuilder(ctx).make_unsqueeze({'data': on_value, 'axes': [0]})
1260+
off_value = GraphBuilder(ctx).make_unsqueeze({'data': off_value, 'axes': [0]})
12611261
off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]
12621262

12631263
indices = node.input[0]
@@ -2385,15 +2385,17 @@ def normalize():
23852385
pad_length_2 = body_graph.make_node('Concat', [zeo, pad_length.output[0]], attr={'axis': 0})
23862386
padded_range = body_graph.make_node('Pad', [sliced_range.output[0], pad_length_2.output[0]])
23872387
# opset == 11, no need to change unsqueeze
2388-
unsqueezed_range = body_graph.make_node('Unsqueeze', [padded_range.output[0]], attr={'axes': [1]})
2388+
unsqueezed_range = GraphBuilder(body_graph).make_unsqueeze(
2389+
{'data': padded_range.output[0], 'axes': [1]}, return_node=True)
23892390
half_shape_x = body_graph.make_node('Slice',
23902391
[new_shape.output[0], zeo, minus_two])
23912392
shape_range = body_graph.make_node('Shape', [unsqueezed_range.output[0]])
23922393
full_shape = body_graph.make_node('Concat', [half_shape_x.output[0], shape_range.output[0]], attr={'axis': 0})
23932394
expanded_range = body_graph.make_node('Expand', [unsqueezed_range.output[0], full_shape.output[0]])
23942395
gathered_input = body_graph.make_node('GatherElements', [processed_input.output[0], expanded_range.output[0]],
23952396
attr={'axis': -1})
2396-
squeezed_input = body_graph.make_node('Squeeze', [gathered_input.output[0]], attr={'axes': [-1]})
2397+
squeezed_input = GraphBuilder(body_graph).make_squeeze(
2398+
{'data': gathered_input.output[0], 'axes': [-1]}, return_node=True)
23972399
left_width = body_graph.make_node('Sub', [new_width.output[0], abs_k.output[0]])
23982400
dims = body_graph.make_node('Concat', [left_width.output[0], new_depth.output[0]], attr={'axis': 0})
23992401
valid_dim = body_graph.make_node('ReduceMin', [dims.output[0]])
@@ -2505,8 +2507,8 @@ def normalize():
25052507
raw_output_shape + [-1])
25062508
squeeze_sliced_graph = ctx.create_new_graph_with_same_config()
25072509
squeeze_sliced_graph.parent_graph = ctx
2508-
squeeze_sliced = squeeze_sliced_graph.make_node('Squeeze', [final_output_right_sliced.output[0]],
2509-
attr={'axes': [-2]})
2510+
squeeze_sliced = GraphBuilder(squeeze_sliced_graph).make_squeeze(
2511+
{'data': final_output_right_sliced.output[0], 'axes': [-2]}, return_node=True)
25102512
squeeze_sliced_graph.add_graph_output(squeeze_sliced.output[0], ctx.get_dtype(node.input[0]), raw_output_shape)
25112513
shapes = node.output_shapes
25122514
dtypes = node.output_dtypes
@@ -2680,14 +2682,14 @@ def version_13(cls, ctx, node, **kwargs):
26802682
@tf_op(["MatrixDiag", "MatrixDiagV2", "MatrixDiagV3"])
26812683
class MatrixDiag:
26822684
@classmethod
2683-
def any_version(cls, opset, ctx, node, **kwargs):
2685+
def version_12(cls, ctx, node, **kwargs):
26842686
# Assemble MatrixDiagV3 by ReverseSequence
26852687
argc = len(node.input)
26862688

2687-
if opset >= 13:
2688-
squeeze_axes0 = ctx.make_const(utils.make_name("const_axes"), np.array([0], dtype=np.int64))
2689-
squeeze_axes_1 = ctx.make_const(utils.make_name("const_axes"), np.array([-1], dtype=np.int64))
2690-
squeeze_axes_2 = ctx.make_const(utils.make_name("const_axes"), np.array([-2], dtype=np.int64))
2689+
if ctx.opset >= 13:
2690+
squeeze_axes0 = ctx.make_const(utils.make_name("const_axes"), np.array([0], dtype=np.int64)).output[0]
2691+
squeeze_axes_1 = ctx.make_const(utils.make_name("const_axes"), np.array([-1], dtype=np.int64)).output[0]
2692+
squeeze_axes_2 = ctx.make_const(utils.make_name("const_axes"), np.array([-2], dtype=np.int64)).output[0]
26912693

26922694
minus_two, minus_one, zeo, one, two = \
26932695
[n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1], [2]])]
@@ -2712,7 +2714,7 @@ def processdiag():
27122714
diag = node.input[0]
27132715
shape = ctx.get_shape(diag)
27142716
if len(shape) == 1:
2715-
if opset < 13:
2717+
if ctx.opset < 13:
27162718
diag = mknode("Unsqueeze", [diag], attr={"axes": [0]})
27172719
else:
27182720
diag = mknode("Unsqueeze", [diag, squeeze_axes0])
@@ -2737,7 +2739,7 @@ def id_diag():
27372739
def ex_diag():
27382740
g = ctx.create_new_graph_with_same_config()
27392741
g.parent_graph = ctx
2740-
if opset < 13:
2742+
if ctx.opset < 13:
27412743
ex = mknode2(g, "Unsqueeze", [diag], attr={"axes": [-2]})
27422744
else:
27432745
ex = mknode2(g, "Unsqueeze", [diag, squeeze_axes_2])
@@ -2755,7 +2757,7 @@ def squeeze_12(name):
27552757
def squeeze_13(name):
27562758
return ctx.make_node("Squeeze", [name, squeeze_axes_1]).output[0]
27572759

2758-
squeeze = squeeze_12 if opset < 13 else squeeze_13
2760+
squeeze = squeeze_12 if ctx.opset < 13 else squeeze_13
27592761

27602762
# gather inputs
27612763
diag, k, k_min, k_max, k_max_nxt = processdiag()
@@ -3018,14 +3020,10 @@ def paddiag():
30183020
ctx.make_node("Identity", [padded], name=node.name,
30193021
outputs=node.output, shapes=shapes, dtypes=dtypes)
30203022

3021-
@classmethod
3022-
def version_12(cls, ctx, node, **kwargs):
3023-
cls.any_version(12, ctx, node, **kwargs)
3024-
30253023
@classmethod
30263024
def version_13(cls, ctx, node, **kwargs):
30273025
# Parameters moved to inputs for operator Squeeze, Unsqueeze.
3028-
cls.any_version(13, ctx, node, **kwargs)
3026+
cls.version_12(ctx, node, **kwargs)
30293027

30303028

30313029
@tf_op("MatrixSetDiagV3")

tf2onnx/rewriter/gru_rewriter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ def process_var_init_nodes(self, context):
211211
const_node = self.g.make_const(initial_name, new_val)
212212
context.onnx_input_ids["initial_state"] = const_node.output[0]
213213
return
214-
squeeze_node = self.g.make_node("Unsqueeze", [initializer_input_id], attr={"axes": [0]})
214+
squeeze_node = GraphBuilder(self.g).make_unsqueeze(
215+
{'data': initializer_input_id, 'axes': [0]}, return_node=True)
215216
to_replace = [n for n in self.g.get_nodes() if n != squeeze_node]
216217
self.g.replace_all_inputs(initializer_input_id, squeeze_node.output[0], ops=to_replace)
217218
context.onnx_input_ids["initial_state"] = squeeze_node.output[0]

0 commit comments

Comments
 (0)