Skip to content

Commit 9940b5b

Browse files
authored
Merge pull request #88 from pengwa/fix_output_dtype
fix the graph output node dtype issue
2 parents 26cf48d + 5e6c259 commit 9940b5b

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

tf2onnx/graph.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ def get_dtype(self, name):
300300
"""Get dtype for node."""
301301
return self._dtypes.get(name)
302302

303+
def set_dtype(self, name, val):
304+
"""Set dtype for node."""
305+
self._dtypes[name] = val
306+
303307
def get_shape(self, name):
304308
"""Get shape for node."""
305309
assert isinstance(name, str)
@@ -390,13 +394,11 @@ def make_model(self, doc, input_names, output_names, optimize=True):
390394
# create output_tensor_values
391395
output_tensor_values = []
392396
for name in output_names:
393-
op = self.get_node_by_name(name)
394-
if op:
395-
dtype = op.dtype
396-
if not dtype:
397-
continue
398-
v = helper.make_tensor_value_info(name, dtype, self.get_shape(name))
399-
output_tensor_values.append(v)
397+
dtype = self.get_dtype(name);
398+
if not dtype:
399+
raise ValueError("cannot found the output dtype for " + name)
400+
v = helper.make_tensor_value_info(name, dtype, self.get_shape(name))
401+
output_tensor_values.append(v)
400402

401403
# update attributes
402404
ops = []

tf2onnx/tfonnx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,9 @@ def topk_op(ctx, node, name, args):
898898
node.set_attr("k", k[0])
899899
node.type = "TopK"
900900
ctx.remove_input(node, node.input[1])
901+
902+
# the second of TopK operator must be INT64 per ONNX requires.
903+
ctx.set_dtype(name + ":1", onnx_pb.TensorProto.INT64)
901904
return node
902905

903906

0 commit comments

Comments
 (0)