Skip to content

Commit 59c5fda

Browse files
committed
fix bug: when deleting node, its shapes and dtypes should be put back to
graph if necessary
1 parent 9713ffd commit 59c5fda

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,8 @@ def version_4(cls, ctx, node, **kwargs):
853853
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
854854
"only support same blocksize at different dims")
855855

856+
shapes = [ctx.get_shape(node.output[0])]
857+
dtypes = [ctx.get_dtype(node.output[0])]
856858
ctx.remove_node(node.name)
857859

858860
# implement pads logic, the data format is NHWC
@@ -866,4 +868,5 @@ def version_4(cls, ctx, node, **kwargs):
866868
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
867869
trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]})
868870
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
869-
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output)
871+
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output,
872+
shapes=shapes, dtypes=dtypes)

0 commit comments

Comments
 (0)