@@ -700,3 +700,34 @@ def atan2(y, x):
700
700
op_name_scope = node .name + 'all' ,
701
701
shapes = [shape ], dtypes = [onnx_dtype ])
702
702
ctx .replace_all_inputs (node .output [0 ], last_node .output [0 ]) # ops=ctx.get_nodes()
703
+
704
+
705
+ @tf_op ("InvertPermutation" )
706
+ class InvertPermutationOp :
707
+
708
+ @classmethod
709
+ def version_11 (cls , ctx , node , ** kwargs ):
710
+
711
+ supported_dtypes = [onnx_pb .TensorProto .INT32 , onnx_pb .TensorProto .INT64 ]
712
+ onnx_dtype = ctx .get_dtype (node .input [0 ])
713
+ utils .make_sure (onnx_dtype in supported_dtypes , "InvertPermutation only applies on INT32, INT64." )
714
+
715
+ shape = ctx .get_shape (node .input [0 ])
716
+
717
+ shape_node = ctx .make_node (
718
+ "Shape" , inputs = node .input , name = utils .make_name (node .name + '_shape' ))
719
+
720
+ neg_node = ctx .make_node (
721
+ "Neg" , inputs = node .input , name = utils .make_name (node .name + '_neg' ))
722
+
723
+ topk_node = ctx .make_node (
724
+ "TopK" , inputs = [neg_node .output [0 ], shape_node .output [0 ]],
725
+ name = utils .make_name (node .name + '_topk' ), output_count = 2 )
726
+
727
+ ctx .remove_node (node .name )
728
+
729
+ last_node = ctx .make_node (
730
+ "Identity" , inputs = topk_node .output [1 :], name = utils .make_name (node .name + '_indices' ),
731
+ shapes = [shape ], dtypes = [onnx_dtype ])
732
+
733
+ ctx .replace_all_inputs (node .output [0 ], last_node .output [0 ]) # ops=ctx.get_nodes()
0 commit comments