From dfb7a6f3ca8b21bfd6a5dfafb06ef1b852a29799 Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Wed, 27 Nov 2024 11:30:56 +0100
Subject: [PATCH] Handle new dims for repeat in pass

In addition to moving logic from the node visitor, this also fixes
repeating a rank 3 tensor to make a rank 4 tensor.

Signed-off-by: Erik Lundell
Change-Id: I7090159bce47b6aa4d6613bbeb2d681d5cfcb193
---
 backends/arm/_passes/arm_pass_manager.py      |  4 +
 .../_passes/unsqueeze_before_repeat_pass.py   | 62 ++++++++++++++++
 backends/arm/operators/op_repeat.py           | 31 +-------
 backends/arm/test/ops/test_repeat.py          | 11 ++-
 .../test_unsqueeze_before_repeat_pass.py      | 74 +++++++++++++++++++
 5 files changed, 151 insertions(+), 31 deletions(-)
 create mode 100644 backends/arm/_passes/unsqueeze_before_repeat_pass.py
 create mode 100644 backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 1e2b26ef645..25811d077bb 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -41,6 +41,9 @@
     ScalarsToAttributePass,
 )
 from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
+from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
+    UnsqueezeBeforeRepeatPass,
+)
 from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
     UnsqueezeScalarPlaceholdersPass,
 )
@@ -66,6 +69,7 @@ def transform_to_backend_pipeline(
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(DecomposeLayerNormPass())
+        self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeMeanDimPass())
diff --git a/backends/arm/_passes/unsqueeze_before_repeat_pass.py b/backends/arm/_passes/unsqueeze_before_repeat_pass.py
new file mode 100644
index 00000000000..01983baa9ab
--- /dev/null
+++ b/backends/arm/_passes/unsqueeze_before_repeat_pass.py
@@ -0,0 +1,62 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+import torch
+import torch.fx
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class UnsqueezeBeforeRepeatPass(ExportPass):
+    """
+    A TOSA TILE op only supports rank(in) == rank(out).
+    To support PyTorch's repeat, which can also add dimensions,
+    we insert an explicit view op before the repeat to add the new dimensions.
+    New dimensions are appended at the front; see
+    https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
+
+    Original:
+        repeat(multiples)
+    After pass:
+        view(shape = [1]*num_new_dims + old_shape)
+        repeat(multiples)
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target != exir_ops.edge.aten.repeat.default:
+                continue
+
+            old_shape = list(get_first_fake_tensor(node.all_input_nodes[0]).shape)
+            old_rank = len(old_shape)
+            multiples = node.args[1]
+            new_rank = len(multiples)
+            if old_rank == new_rank:
+                continue
+
+            num_new_dims = new_rank - old_rank
+            new_shape = [1] * num_new_dims + old_shape
+
+            with graph_module.graph.inserting_before(node):
+                view_node = create_node(
+                    graph_module.graph,
+                    exir_ops.edge.aten.view_copy.default,
+                    (node.all_input_nodes[0], new_shape),
+                )
+                node.replace_input_with(node.all_input_nodes[0], view_node)
+                modified_graph = True
+
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified_graph)
diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py
index 20de9e0846a..1e4dc4e23cd 100644
--- a/backends/arm/operators/op_repeat.py
+++ b/backends/arm/operators/op_repeat.py
@@ -32,37 +32,8 @@ def define_node(
         is_quant_node: bool,
     ) -> None:

-        item_name = inputs[0].name
-        shape = inputs[0].shape
-        rank = len(shape)
         multiples = inputs[1].special
-        new_rank = len(multiples)
-
-        assert new_rank >= rank
-
-        # TILE only supports rank(in) == rank(out). To add more dims, we need a reshape first.
-        if new_rank > rank:
-            # Add length 1 dimensions to shape to match multiples
-            num_new_dims = new_rank - rank
-            expanded_shape = tuple(
-                1 if i < num_new_dims else shape[i - num_new_dims]
-                for i in range(new_rank)
-            )
-            expanded_shape = tosa_shape(expanded_shape, output.dim_order)
-            dtype = (
-                ts.dtype_str_to_val("INT8")
-                if is_quant_node
-                else ts.dtype_str_to_val("FP32")
-            )
-
-            rescale_out = tosa_graph.addIntermediate(expanded_shape, dtype)
-            rescale_attr = ts.TosaSerializerAttribute()
-            rescale_attr.ReshapeAttribute(expanded_shape)
-            tosa_graph.addOperator(
-                TosaOp.Op().RESHAPE, [item_name], [rescale_out.name], rescale_attr
-            )
-            item_name = rescale_out.name

         attr = ts.TosaSerializerAttribute()
         attr.TileAttribute(tosa_shape(multiples, output.dim_order))
-        tosa_graph.addOperator(TosaOp.Op().TILE, [item_name], [output.name], attr)
+        tosa_graph.addOperator(TosaOp.Op().TILE, [inputs[0].name], [output.name], attr)
diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py
index 20c57ba749c..de555e7c803 100644
--- a/backends/arm/test/ops/test_repeat.py
+++ b/backends/arm/test/ops/test_repeat.py
@@ -37,6 +37,7 @@ class Repeat(torch.nn.Module):
         (torch.randn(3), (2, 2)),
         (torch.randn(3), (1, 2, 3)),
         (torch.randn((3, 3)), (2, 2, 2)),
+        (torch.randn((3, 3, 3)), (2, 1, 2, 4)),
     ]

     def forward(self, x: torch.Tensor, multiples: Sequence):
@@ -106,12 +107,20 @@ def test_repeat_tosa_MI(self, test_input, multiples):
     def test_repeat_tosa_BI(self, test_input, multiples):
         self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))

-    @parameterized.expand(Repeat.test_parameters)
+    @parameterized.expand(Repeat.test_parameters[:-1])
     def test_repeat_u55_BI(self, test_input, multiples):
         self._test_repeat_ethosu_pipeline(
             common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
         )

+    # Final test requires transpose which is not supported on u55.
+    @parameterized.expand(Repeat.test_parameters[-1:])
+    @unittest.expectedFailure
+    def test_repeat_u55_BI_xfails(self, test_input, multiples):
+        self._test_repeat_ethosu_pipeline(
+            common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
+        )
+
     @parameterized.expand(Repeat.test_parameters)
     def test_repeat_u85_BI(self, test_input, multiples):
         self._test_repeat_ethosu_pipeline(
diff --git a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py
new file mode 100644
index 00000000000..d249c18ec85
--- /dev/null
+++ b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py
@@ -0,0 +1,74 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import unittest
+
+import torch
+from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
+    UnsqueezeBeforeRepeatPass,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.backends.xnnpack.test.tester.tester import RunPasses
+
+
+class Repeat(torch.nn.Module):
+    """
+    Basic repeat model.
+    """
+
+    def forward(self, x: torch.Tensor):
+        return x.repeat(2, 2, 2, 2)
+
+
+class TestUnsqueezeBeforeRepeatPass(unittest.TestCase):
+    def test_tosa_MI_insert_view(self):
+        """
+        When rank(input) != number of repeated dimensions (=4 in Repeat module),
+        insert view.
+        """
+        module = Repeat()
+        inputs = (torch.rand((2, 3, 4)),)
+        test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass])
+        (
+            (
+                ArmTester(
+                    module,
+                    example_inputs=inputs,
+                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
+                )
+                .export()
+                .to_edge()
+                .check(["aten_repeat_default"])
+                .check_not(["aten_view_copy_default"])
+                .run_passes(test_pass_stage)
+                .check(["aten_repeat_default", "aten_view_copy_default"])
+            )
+        )
+
+    def test_tosa_MI_dont_insert_view(self):
+        """
+        When rank(input) == number of repeated dimensions (=4 in Repeat module),
+        DON'T insert view.
+        """
+        module = Repeat()
+        inputs = (torch.rand((2, 3, 4, 1)),)
+        test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass])
+        (
+            (
+                ArmTester(
+                    module,
+                    example_inputs=inputs,
+                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
+                )
+                .export()
+                .to_edge()
+                .check(["aten_repeat_default"])
+                .check_not(["aten_view_copy_default"])
+                .run_passes(test_pass_stage)
+                .check(["aten_repeat_default"])
+                .check_not(["aten_view_copy_default"])
+            )
+        )
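
Note for reviewers: below is a minimal standalone sketch (not part of the
patch) of the rank alignment the new pass performs, checked against the test
case added in test_repeat.py. The helper name `unsqueeze_then_repeat` is
illustrative only and does not exist in the ExecuTorch codebase.

    import torch

    def unsqueeze_then_repeat(x: torch.Tensor, multiples: tuple) -> torch.Tensor:
        # repeat() may receive more multiples than the tensor has dims;
        # the extra dims are prepended, just like the view the pass inserts.
        num_new_dims = len(multiples) - x.dim()
        # view(shape = [1]*num_new_dims + old_shape), per the pass docstring.
        x = x.view([1] * num_new_dims + list(x.shape))
        # Rank now equals len(multiples), so a rank-preserving TILE suffices.
        return x.repeat(multiples)

    # The new test case: rank-3 input repeated with four multiples.
    out = unsqueeze_then_repeat(torch.randn(3, 3, 3), (2, 1, 2, 4))
    assert out.shape == torch.Size([2, 3, 6, 12])

Doing the rank alignment once in a graph pass means that by the time
define_node in op_repeat.py runs, rank(in) == rank(out) already holds, so the
node visitor reduces to a single TILE.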