Skip to content

Commit 5e6c259

Browse files
committed
onnx cannot load int32 when top10k:1 is the graph output node
1 parent 4cf0d87 commit 5e6c259

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

tf2onnx/graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ def get_dtype(self, name):
300300
"""Get dtype for node."""
301301
return self._dtypes.get(name)
302302

303+
def set_dtype(self, name, val):
304+
"""Set dtype for node."""
305+
self._dtypes[name] = val
306+
303307
def get_shape(self, name):
304308
"""Get shape for node."""
305309
assert isinstance(name, str)

tf2onnx/tfonnx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,9 @@ def topk_op(ctx, node, name, args):
884884
node.set_attr("k", k[0])
885885
node.type = "TopK"
886886
ctx.remove_input(node, node.input[1])
887+
888+
# the second of TopK operator must be INT64 per ONNX requires.
889+
ctx.set_dtype(name + ":1", onnx_pb.TensorProto.INT64)
887890
return node
888891

889892

0 commit comments

Comments
 (0)