Skip to content

Commit aecd66e

Browse files
committed
removed train outputs from bn if they are not connected
1 parent 54b5b90 commit aecd66e

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tf2onnx/tfonnx.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,13 @@ def fused_batchnorm_op7(ctx, node, name, args):
13321332
# a: data_format, epsilon, is_training
13331333
# onnx inputs: X, scale, B, mean, variance, attributes: epsilon, momentum=0.9, spatial : 1
13341334
# output: y, mean, var, savedmean, savedvar,
1335+
1336+
# detach unused outputs. While we could let the unused outputs dangle,
1337+
# some runtimes like pytorch/caffe2 do complain about it.
1338+
consumers = [ctx.find_output_consumers(output_name) for output_name in node.output[1:]]
1339+
if not any(consumers):
1340+
del node.output[1:]
1341+
13351342
nodes = conv_convert_inputs(ctx, node, with_kernel=False)
13361343
scale_shape = ctx.get_shape(node.input[1])
13371344
mean_shape = ctx.get_shape(node.input[3])

0 commit comments

Comments
 (0)