Skip to content

Commit d2e0aaa

Browse files
Fixed bugs in squeeze/unsqueeze for string ops (#1261)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 2259c3d commit d2e0aaa

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

tf2onnx/custom_opsets/string_ops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ def any_version(cls, opset, ctx, node, **kwargs):
2626
node.domain = constants.CONTRIB_OPS_DOMAIN
2727
for a in list(node.attr.keys()):
2828
del node.attr[a]
29-
unsqueeze_node = GraphBuilder(ctx).make_squeeze(
30-
{'data': node.input[1], 'axes': [0]}, return_node=True)
29+
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': node.input[1], 'axes': [0]}, return_node=True)
3130

3231
skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'), np.array([skip_empty], np.bool))
3332
ctx.replace_inputs(node, [node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0]])
@@ -88,8 +87,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
8887
if ctx.get_shape(inp) == [] and shape_node is not None:
8988
expand_node = ctx.make_node("Expand", [inp, shape_node.output[0]])
9089
inp = expand_node.output[0]
91-
unsqueeze_node = GraphBuilder(ctx).make_squeeze({'data': inp, 'axes': [0]})
92-
unsqueezes.append(unsqueeze_node.output[0])
90+
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': inp, 'axes': [0]})
91+
unsqueezes.append(unsqueeze_node)
9392
stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
9493
ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
9594

0 commit comments

Comments
 (0)