Skip to content

Commit fb9106d

Browse files
committed
Addresses comments
Signed-off-by: xavier dupré <[email protected]>
1 parent f7aa100 commit fb9106d

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -704,15 +704,14 @@ def atan2(y, x):
704704

705705
@tf_op("InvertPermutation")
706706
class InvertPermutationOp:
707-
supported_dtypes = [
708-
onnx_pb.TensorProto.INT32,
709-
onnx_pb.TensorProto.INT64,
710-
]
711707

712708
@classmethod
713709
def version_11(cls, ctx, node, **kwargs):
714710

711+
supported_dtypes = [onnx_pb.TensorProto.INT32, onnx_pb.TensorProto.INT64]
715712
onnx_dtype = ctx.get_dtype(node.input[0])
713+
utils.make_sure(onnx_dtype in supported_dtypes, "InvertPermutation only applies on INT32, INT64.")
714+
716715
shape = ctx.get_shape(node.input[0])
717716

718717
shape_node = ctx.make_node(
@@ -721,12 +720,9 @@ def version_11(cls, ctx, node, **kwargs):
721720
neg_node = ctx.make_node(
722721
"Neg", inputs=node.input, name=utils.make_name(node.name + '_neg'))
723722

724-
topk_unused = utils.make_name(node.name + '_topk')
725-
topk_indices = utils.make_name(node.name + '_indices')
726-
outputs = [topk_unused, utils.port_name(topk_indices, 1)]
727723
topk_node = ctx.make_node(
728724
"TopK", inputs=[neg_node.output[0], shape_node.output[0]],
729-
name=utils.make_name(node.name + '_topk'), outputs=outputs)
725+
name=utils.make_name(node.name + '_topk'), output_count=2)
730726

731727
ctx.remove_node(node.name)
732728

0 commit comments

Comments
 (0)