Skip to content

Commit 0d4fbb6

Browse files
committed
Adjust spatial shape to meet rank requirements for Conv*D
1 parent 301fcc0 commit 0d4fbb6

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
204204
input_shape = ctx.get_shape(node.input[0])
205205
output_shape = ctx.get_shape(node.output[0])
206206

207+
# prefix with batch dim of [1] to satisfy rank requirements
208+
if len(input_shape) == spatial + 1:
209+
input_shape = [1] + input_shape
210+
ctx.set_shape(node.input[0], input_shape)
211+
207212
if len(input_shape) != spatial + 2:
208213
raise ValueError(
209214
"node {} output needs to be rank {}, is {}".format(

0 commit comments

Comments
 (0)