@@ -697,3 +697,38 @@ def atan2(y, x):
697
697
op_name_scope = node .name + 'all' ,
698
698
shapes = [shape ], dtypes = [onnx_dtype ])
699
699
ctx .replace_all_inputs (node .output [0 ], last_node .output [0 ]) # ops=ctx.get_nodes()
700
+
701
+
702
+ @tf_op ("InvertPermutation" )
703
+ class InvertPermutationOp :
704
+ supported_dtypes = [
705
+ onnx_pb .TensorProto .INT32 ,
706
+ onnx_pb .TensorProto .INT64 ,
707
+ ]
708
+
709
+ @classmethod
710
+ def version_10 (cls , ctx , node , ** kwargs ):
711
+
712
+ onnx_dtype = ctx .get_dtype (node .input [0 ])
713
+ shape = ctx .get_shape (node .input [0 ])
714
+
715
+ shape_node = ctx .make_node (
716
+ "Shape" , inputs = node .input , name = utils .make_name (node .name + '_shape' ))
717
+
718
+ neg_node = ctx .make_node (
719
+ "Neg" , inputs = node .input , name = utils .make_name (node .name + '_neg' ))
720
+
721
+ topk_unused = utils .make_name (node .name + '_topk' )
722
+ topk_indices = utils .make_name (node .name + '_indices' )
723
+ outputs = [topk_unused , utils .port_name (topk_indices , 1 )]
724
+ topk_node = ctx .make_node (
725
+ "TopK" , inputs = [neg_node .output [0 ], shape_node .output [0 ]],
726
+ name = utils .make_name (node .name + '_topk' ), outputs = outputs )
727
+
728
+ ctx .remove_node (node .name )
729
+
730
+ last_node = ctx .make_node (
731
+ "Identity" , inputs = topk_node .output [1 :], name = utils .make_name (node .name + '_indices' ),
732
+ shapes = [shape ], dtypes = [onnx_dtype ])
733
+
734
+ ctx .replace_all_inputs (node .output [0 ], last_node .output [0 ]) # ops=ctx.get_nodes()
0 commit comments