Skip to content

Commit 1e2f4b2

Browse files
committed
Fixes #595, support operator InvertPermutation
Signed-off-by: xavier dupré <[email protected]>
1 parent c9b4219 commit 1e2f4b2

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3676,6 +3676,19 @@ def func3(x):
36763676
with self.assertRaises(ValueError):
36773677
self._run_test_case(func3, [_OUTPUT], {_INPUT: x_val})
36783678

3679+
@check_opset_min_version(10, "topk")
3680+
def test_invert_permutation(self):
3681+
3682+
def func(x):
3683+
op_ = tf.math.invert_permutation(x)
3684+
return tf.identity(op_, name=_TFOUTPUT)
3685+
3686+
x_val = np.array([0, 1, 2, 3], dtype=np.int64)
3687+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3688+
3689+
x_val = np.array([1, 5, 2, 0, 3, 4], dtype=np.int64)
3690+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3691+
36793692

36803693
if __name__ == '__main__':
36813694
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,38 @@ def atan2(y, x):
697697
op_name_scope=node.name + 'all',
698698
shapes=[shape], dtypes=[onnx_dtype])
699699
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
700+
701+
702+
@tf_op("InvertPermutation")
703+
class InvertPermutationOp:
704+
supported_dtypes = [
705+
onnx_pb.TensorProto.INT32,
706+
onnx_pb.TensorProto.INT64,
707+
]
708+
709+
@classmethod
710+
def version_10(cls, ctx, node, **kwargs):
711+
712+
onnx_dtype = ctx.get_dtype(node.input[0])
713+
shape = ctx.get_shape(node.input[0])
714+
715+
shape_node = ctx.make_node(
716+
"Shape", inputs=node.input, name=utils.make_name(node.name + '_shape'))
717+
718+
neg_node = ctx.make_node(
719+
"Neg", inputs=node.input, name=utils.make_name(node.name + '_neg'))
720+
721+
topk_unused = utils.make_name(node.name + '_topk')
722+
topk_indices = utils.make_name(node.name + '_indices')
723+
outputs = [topk_unused, utils.port_name(topk_indices, 1)]
724+
topk_node = ctx.make_node(
725+
"TopK", inputs=[neg_node.output[0], shape_node.output[0]],
726+
name=utils.make_name(node.name + '_topk'), outputs=outputs)
727+
728+
ctx.remove_node(node.name)
729+
730+
last_node = ctx.make_node(
731+
"Identity", inputs=topk_node.output[1:], name=utils.make_name(node.name + '_indices'),
732+
shapes=[shape], dtypes=[onnx_dtype])
733+
734+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()

0 commit comments

Comments
 (0)