From 5866ca7a1b36be96c24cf081982360795e11173f Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 4 Feb 2025 11:14:51 +0000 Subject: [PATCH] Squeeze on the dimension we have selected - If we squeeze using the input rank and batch_size==1 this dimension will also be removed. --- backends/arm/_passes/decompose_select.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index 9ea836e6336..5e04668df9a 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -37,14 +37,13 @@ def call(self, graph_module: torch.fx.GraphModule): rank = len(input_node.meta["val"].size()) dim = dim % rank if dim < 0 else dim index = index % rank if index < 0 else index - dim_list = list(range(rank)) with graph_module.graph.inserting_before(node): slice_node = create_node( graph_module.graph, slice_op, (input_node, dim, index, index + 1) ) squeeze_node = create_node( - graph_module.graph, squeeze_op, (slice_node, dim_list) + graph_module.graph, squeeze_op, (slice_node, [dim]) ) node.replace_all_uses_with(squeeze_node)