From c3c71d65d5c8fdb5a7fafe17996280f3473f0cf0 Mon Sep 17 00:00:00 2001
From: Gregory James Comer
Date: Wed, 10 Sep 2025 12:01:38 -0700
Subject: [PATCH] Fix XNNPACK handling of negative permute dims

---
Notes (not part of the commit message):
* Negative entries in the permute order are mapped to dim + input_rank,
  with input_rank read from input_node.meta["val"].dim(); this happens
  before the channels-last handling that follows in define_node.
* The new test permutes a (1, 1, 4, 4) input with [0, -2, -1, 1] and expects
  one executorch_call_delegate and no edge permute_copy op after lowering.
* NOTE(review): the From: header in this copy carries no e-mail address;
  `git am` may reject the patch without one - restore it if needed.

 backends/xnnpack/operators/op_permute.py  |  9 ++++++++-
 backends/xnnpack/test/ops/test_permute.py | 14 ++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/backends/xnnpack/operators/op_permute.py b/backends/xnnpack/operators/op_permute.py
index 4d62d457cd0..d4086ed68b7 100644
--- a/backends/xnnpack/operators/op_permute.py
+++ b/backends/xnnpack/operators/op_permute.py
@@ -44,14 +44,21 @@ def define_node(
         self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
 
         # input
-        input_id = vals_to_ids[get_input_node(node, 0)]
+        input_node = get_input_node(node, 0)
+        input_id = vals_to_ids[input_node]
 
         # output
         output_id = vals_to_ids[node]
 
         # permutation
+        input_rank = input_node.meta["val"].dim()
         permute_order = cast(List[int], node.args[1])
 
+        # Handle negative dimensions by converting them to positive indices
+        permute_order = [
+            (dim + input_rank) if dim < 0 else dim for dim in permute_order
+        ]
+
         # change permute order if under channels last
         is_channels_last = node.meta.get(
             ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
diff --git a/backends/xnnpack/test/ops/test_permute.py b/backends/xnnpack/test/ops/test_permute.py
index 2991ba1773d..e0171f2178f 100644
--- a/backends/xnnpack/test/ops/test_permute.py
+++ b/backends/xnnpack/test/ops/test_permute.py
@@ -55,6 +55,20 @@ def test_fp32_permute(self):
         inputs = (torch.randn(1, 1, 4, 4),)
         self._test_permute(inputs)
 
+    def test_fp32_permute_negative_dim(self):
+        inputs = (torch.randn(1, 1, 4, 4),)
+        (
+            Tester(self.Permute([0, -2, -1, 1]), inputs)
+            .export()
+            .check_count({"torch.ops.aten.permute.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
     def test_fp32_permute_copy(self):
         inputs = (torch.randn(1, 1, 4, 4),)
         (