Skip to content

Commit ba533d0

Browse files
authored
Merge pull request #1117 from xadupre/ip
Fixes #595, support operator InvertPermutation
2 parents c109241 + fb9106d commit ba533d0

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3676,6 +3676,19 @@ def func3(x):
36763676
with self.assertRaises(ValueError):
36773677
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
36783678

3679+
@check_opset_min_version(11, "topk")
3680+
def test_invert_permutation(self):
3681+
3682+
def func(x):
3683+
op_ = tf.math.invert_permutation(x)
3684+
return tf.identity(op_, name=_TFOUTPUT)
3685+
3686+
x_val = np.array([0, 1, 2, 3], dtype=np.int64)
3687+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3688+
3689+
x_val = np.array([1, 5, 2, 0, 3, 4], dtype=np.int64)
3690+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3691+
36793692

36803693
if __name__ == '__main__':
36813694
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,3 +700,34 @@ def atan2(y, x):
700700
op_name_scope=node.name + 'all',
701701
shapes=[shape], dtypes=[onnx_dtype])
702702
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
703+
704+
705+
@tf_op("InvertPermutation")
706+
class InvertPermutationOp:
707+
708+
@classmethod
709+
def version_11(cls, ctx, node, **kwargs):
710+
711+
supported_dtypes = [onnx_pb.TensorProto.INT32, onnx_pb.TensorProto.INT64]
712+
onnx_dtype = ctx.get_dtype(node.input[0])
713+
utils.make_sure(onnx_dtype in supported_dtypes, "InvertPermutation only applies on INT32, INT64.")
714+
715+
shape = ctx.get_shape(node.input[0])
716+
717+
shape_node = ctx.make_node(
718+
"Shape", inputs=node.input, name=utils.make_name(node.name + '_shape'))
719+
720+
neg_node = ctx.make_node(
721+
"Neg", inputs=node.input, name=utils.make_name(node.name + '_neg'))
722+
723+
topk_node = ctx.make_node(
724+
"TopK", inputs=[neg_node.output[0], shape_node.output[0]],
725+
name=utils.make_name(node.name + '_topk'), output_count=2)
726+
727+
ctx.remove_node(node.name)
728+
729+
last_node = ctx.make_node(
730+
"Identity", inputs=topk_node.output[1:], name=utils.make_name(node.name + '_indices'),
731+
shapes=[shape], dtypes=[onnx_dtype])
732+
733+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()

0 commit comments

Comments
 (0)