Skip to content

Commit 94317a3

Browse files
authored
Merge pull request #85 from onnx/gs/lstm
fix dtype bug
2 parents 5e303c4 + 4e7d14a commit 94317a3

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tf2onnx/tfonnx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,9 @@ def minmax_op(ctx, node, name, args):
911911
new_nodes = []
912912
for i in needs_broadcast_op:
913913
input_node = node.inputs[i]
914-
dtype = ctx.get_dtype(node.input[i])
914+
# we don't track dtype for inserted ops but we know the output dtype here is
915+
# the one we want.
916+
dtype = ctx.get_dtype(node.output[0])
915917
zero_name = utils.make_name(input_node.name)
916918
ctx.make_const(zero_name, "Const", np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
917919
op_name = utils.make_name(input_node.name)

0 commit comments

Comments
 (0)