@@ -704,15 +704,14 @@ def atan2(y, x):
704
704
705
705
@tf_op ("InvertPermutation" )
706
706
class InvertPermutationOp :
707
- supported_dtypes = [
708
- onnx_pb .TensorProto .INT32 ,
709
- onnx_pb .TensorProto .INT64 ,
710
- ]
711
707
712
708
@classmethod
713
709
def version_11 (cls , ctx , node , ** kwargs ):
714
710
711
+ supported_dtypes = [onnx_pb .TensorProto .INT32 , onnx_pb .TensorProto .INT64 ]
715
712
onnx_dtype = ctx .get_dtype (node .input [0 ])
713
+ utils .make_sure (onnx_dtype in supported_dtypes , "InvertPermutation only applies on INT32, INT64." )
714
+
716
715
shape = ctx .get_shape (node .input [0 ])
717
716
718
717
shape_node = ctx .make_node (
@@ -721,12 +720,9 @@ def version_11(cls, ctx, node, **kwargs):
721
720
neg_node = ctx .make_node (
722
721
"Neg" , inputs = node .input , name = utils .make_name (node .name + '_neg' ))
723
722
724
- topk_unused = utils .make_name (node .name + '_topk' )
725
- topk_indices = utils .make_name (node .name + '_indices' )
726
- outputs = [topk_unused , utils .port_name (topk_indices , 1 )]
727
723
topk_node = ctx .make_node (
728
724
"TopK" , inputs = [neg_node .output [0 ], shape_node .output [0 ]],
729
- name = utils .make_name (node .name + '_topk' ), outputs = outputs )
725
+ name = utils .make_name (node .name + '_topk' ), output_count = 2 )
730
726
731
727
ctx .remove_node (node .name )
732
728
0 commit comments