Skip to content

Commit 7391004

Browse files
committed
Use unsqueeze op instead
1 parent 0d4fbb6 commit 7391004

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,6 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
204204
input_shape = ctx.get_shape(node.input[0])
205205
output_shape = ctx.get_shape(node.output[0])
206206

207-
# prefix with batch dim of [1] to satisfy rank requirements
208-
if len(input_shape) == spatial + 1:
209-
input_shape = [1] + input_shape
210-
ctx.set_shape(node.input[0], input_shape)
211-
212207
if len(input_shape) != spatial + 2:
213208
raise ValueError(
214209
"node {} output needs to be rank {}, is {}".format(
@@ -346,6 +341,11 @@ def version_1(cls, ctx, node, **kwargs):
346341
strides = conv_dims_attr(node, "strides", spatial=spatial)
347342
dilations = conv_dims_attr(node, "dilations", spatial=spatial)
348343

344+
# prefix with batch dim of [1] to satisfy rank requirements
345+
input_shape = ctx.get_shape(node.input[0])
346+
if len(input_shape) == spatial + 1:
347+
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
348+
349349
# Set padding.
350350
add_padding(
351351
ctx, node, kernel_shape, strides, dilations=dilations, spatial=spatial

0 commit comments

Comments
 (0)