Skip to content

Commit bee1273

Browse files
Fix Conv groups for unknown input dims (#1552)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 963a51e commit bee1273

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,12 +361,13 @@ def any_version(cls, opset, ctx, node, **kwargs):
361361
)
362362
groups = int(1)
363363
data_format = str(node.attr["data_format"].s, encoding="utf8")
364+
shape_dim = -1
364365
if data_format == "NHWC":
365-
groups = int(ctx.get_shape(node.input[0])[3] / ctx.get_shape(node.input[1])[2])
366+
shape_dim = ctx.get_shape(node.input[0])[3]
366367
elif data_format == "NCHW":
367-
groups = int(ctx.get_shape(node.input[0])[1] / ctx.get_shape(node.input[1])[2])
368-
else:
369-
pass
368+
shape_dim = ctx.get_shape(node.input[0])[1]
369+
if shape_dim != -1:
370+
groups = int(shape_dim / ctx.get_shape(node.input[1])[2])
370371

371372
node.set_attr("group", groups)
372373

0 commit comments

Comments
 (0)