Skip to content

Commit 0083dfa

Browse files
Fix bug in topk casts affecting other input consumers (#1674)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 508f6be commit 0083dfa

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
11721172
# cast X if needed
11731173
if dtypes[0] != onnx_pb.TensorProto.FLOAT:
11741174
# opset-10 supports types other than float but onnxruntime does not
1175-
ctx.insert_new_node_on_output("Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
1175+
ctx.insert_new_node_on_input(node, "Cast", node.input[0], input_index=0, to=onnx_pb.TensorProto.FLOAT)
11761176
ctx.insert_new_node_on_output("Cast", node.output[0], to=dtypes[0])
11771177
# cast the index output to int32
11781178
cast_out = ctx.insert_new_node_on_output("Cast", node.output[1], name=utils.make_name(node.name), to=dtypes[1])

0 commit comments

Comments
 (0)