Skip to content

Commit 8afdae5

Browse files
Fix 2 bugs to avoid extranious warnings (#1292)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent bb7df35 commit 8afdae5

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ def version_6(cls, ctx, node, **kwargs):
816816

817817
pop_var_squeezed = ctx.make_node("Div", [var_squeezed, cnt_float.output[0]]).output[0]
818818
ctx.replace_inputs(node, node.input[:3] + [avg_squeezed, pop_var_squeezed])
819-
else:
819+
elif is_training:
820820
logger.warning("Node %s of type %s has is_training set to true, which is not supperted. "
821821
"Please re-save the model with training set to false.",
822822
node.name, tf_type)

tf2onnx/onnx_opset/tensor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,10 +1194,11 @@ def version_1(cls, ctx, node, **kwargs):
11941194
# for each output we need to squeeze axis
11951195
for n in node.output:
11961196
op_name = utils.make_name(node.name)
1197-
squeeze_node = GraphBuilder(ctx).make_squeeze({'data': n, 'axes': [axis]}, name=op_name, return_node=True)
1197+
shape = ctx.get_shape(n)
1198+
dtype = ctx.get_dtype(n)
1199+
squeeze_node = GraphBuilder(ctx).make_squeeze(
1200+
{'data': n, 'axes': [axis]}, name=op_name, return_node=True, shapes=[shape], dtypes=[dtype])
11981201
ctx.insert_node_on_output(squeeze_node, n)
1199-
ctx.copy_shape(n, squeeze_node.output[0])
1200-
ctx.copy_dtype(n, squeeze_node.output[0])
12011202

12021203
# split node is 1 rank higher than squeeze nodes
12031204
output_shape = ctx.get_shape(node.output[0])

0 commit comments

Comments
 (0)