Skip to content

Commit 38b19b2

Browse files
authored
Merge pull request #1005 from jignparm/jignparm/fixpadshape
Adjust spatial shape to meet rank requirements for Conv*D
2 parents 301fcc0 + 7391004 commit 38b19b2

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,11 @@ def version_1(cls, ctx, node, **kwargs):
341341
strides = conv_dims_attr(node, "strides", spatial=spatial)
342342
dilations = conv_dims_attr(node, "dilations", spatial=spatial)
343343

344+
# prefix with batch dim of [1] to satisfy rank requirements
345+
input_shape = ctx.get_shape(node.input[0])
346+
if len(input_shape) == spatial + 1:
347+
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
348+
344349
# Set padding.
345350
add_padding(
346351
ctx, node, kernel_shape, strides, dilations=dilations, spatial=spatial

0 commit comments

Comments
 (0)