From d0c08ca02cb4a9d4772315d6ac4152c05f66770c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Mon, 16 Jun 2025 13:39:30 +0200 Subject: [PATCH] Arm backend: Move rescales from SUM visitor to pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the SUM node visitor, an INT8->INT32 RESCALE node is inserted prior to the SUM node; similarly, an INT32->INT8 RESCALE node is inserted after. This patch moves the insertion to `InsertRescaleInt32Pass`. Since SUM is decomposed, insertion of RESCALE nodes should be carried out before `DecomposeSumPass` (which decomposes SUM into a chain of single dim SUMs). The ordering is important to avoid redundant INT8/INT32 RESCALE nodes being inserted between each SUM node in the chain after decomposition. Only one INT8->INT32 RESCALE is needed before the chain, and an INT32->INT8 after it; between the SUM nodes in the chain, the edges are already in the correct INT32 data type. Signed-off-by: Martin Lindström Change-Id: I86dd5c34b50ca6cbba6ad98e1490c9b7effc3b3b --- backends/arm/_passes/arm_pass_manager.py | 5 +- backends/arm/_passes/decompose_sum_pass.py | 2 +- backends/arm/_passes/insert_rescales_pass.py | 10 ++- .../operator_support/reduce_sum_support.py | 9 ++- backends/arm/operators/op_sum.py | 66 +------------------ backends/arm/test/ops/test_sum.py | 1 + .../passes/test_insert_rescale_i32_pass.py | 55 ++++++++++++---- 7 files changed, 64 insertions(+), 84 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b579d910752..98240f6dc1d 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -194,7 +194,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(DecomposeSumPass()) self.add_pass(DecomposeCumsumPass(exported_program)) self.add_pass(Conv1dUnsqueezePass()) self.add_pass(DecomposeMaxPool2DPass()) @@ -215,10 +214,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RewriteMatmulPass()) self.add_pass(RewriteUpsamplePass()) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) + self.add_pass(InsertRescaleInt32Pass()) + self.add_pass(DecomposeSumPass()) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) - self.add_pass(InsertRescaleInt32Pass()) self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) @@ -361,7 +361,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(ConvertMinMaxPass()) self.add_pass(ReplaceInfValues()) - self.add_pass(DecomposeSumPass()) if not self.tosa_spec.is_U55_subset: # Uses where which is not supported on Ethos-U55 diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 59c352a0e07..989d299a7e8 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta): if not keepdims: shape = list(meta["val"].size()) input_node = super().call_operator( - view_op, (input_node, shape), kwargs, meta, updated=True + view_op, (input_node, shape), {}, meta, updated=True ) return input_node diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index a7fa614c8c3..81226207296 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -10,6 +10,7 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_output_qparams, ) @@ -84,7 +85,11 @@ class InsertRescaleInt32Pass(ArmPass): parameters. """ - _passes_required_after: Set[Type[ExportPass]] = set() + # SUM must be decomposed after this pass to prevent insertion of RESCALE + # nodes between each subsequent SUM node after decomposition. RESCALE nodes + # should only be inserted before and after the SUM node prior to its + # decomposition. + _passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass} included_targets = [ exir_ops.edge.aten.abs.default, @@ -96,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sum.dim_IntList, ] def _int32_qargs(self, s): @@ -138,6 +144,7 @@ def _get_inputs_rescaled_qparams( } elif target in [ exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sum.dim_IntList, ]: # The input scales do not need to be adjusted for these ops; they # can remain the same. @@ -160,6 +167,7 @@ def _get_output_qparams( exir_ops.edge.aten.abs.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.sum.dim_IntList, ]: # The op has not altered the scale; the output scale is equal to # the operands' scales. diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index 4ff8f54ad69..76d1ba7bf36 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -29,8 +29,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): # U55 case, Vela 4.2.0 (25.02 release) input_shape = node.all_input_nodes[0].meta["val"].shape - dim_list = cast(list[int], node.args[1]) - dim_list = [dim % len(input_shape) for dim in dim_list] + + if node.args[1] is None: + # Dim is allowed to be None, which means to sum all dimensions + dim_list = list(range(len(input_shape))) + else: + dim_list = cast(list[int], node.args[1]) + dim_list = [dim % len(input_shape) for dim in dim_list] for dim in dim_list: if not 1 <= input_shape[dim] <= 65536: diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 3f637d18390..98a2685e735 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -25,69 +23,14 @@ @register_node_visitor -class SumVisitor_INT(NodeVisitor): +class SumVisitor(NodeVisitor): target = "aten.sum.dim_IntList" tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+FP"), TosaSpecification.create_from_string("TOSA-1.0+INT"), ] - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 3) - validate_same_dtype(self.target, [inputs[0], output], ts) - - tensor = inputs[0] - input_shape = list(tensor.shape) - dim = int(inputs[1].number % len(input_shape)) - - output_shape = input_shape - output_shape[dim] = 1 # Output shape is input shape with dim reduced - - # Rescale input to 32 bit - rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32( - tosa_graph, [tensor], node, self.tosa_spec - ) - - attr = ts.TosaSerializerAttribute() - attr.ReduceSumAttribute(tensor.dim_order.index(dim)) - - intermediate = tosa_graph.addIntermediate( - tutils.tosa_shape(output_shape, tensor.dim_order), - dtype=ts.DType.INT32, - ) - - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().REDUCE_SUM, - [rescaled_inputs[0].name], - [intermediate.name], - attr, - ) - - tqutils.insert_rescale_op_to_int8( - tosa_graph, intermediate, scale, node, self.tosa_spec - ) - - -@register_node_visitor -class SumVisitor_FP(SumVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -102,9 +45,6 @@ def define_node( input_shape = list(tensor.shape) dim = int(inputs[1].number % len(input_shape)) - output_shape = input_shape - output_shape[dim] = 1 # Output shape is input shape with dim reduced - attr = ts.TosaSerializerAttribute() attr.ReduceSumAttribute(tensor.dim_order.index(dim)) @@ -112,7 +52,7 @@ def define_node( node, tosa_graph, ts.TosaOp.Op().REDUCE_SUM, - [tensor.name], + [inputs[0].name], [output.name], attr, ) diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 13c1e029032..f0af9a022e8 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -35,6 +35,7 @@ class Sum(torch.nn.Module): "4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True), "4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True), "dim_None": lambda: (torch.rand(10), None, True), + "dim_None_4d_tensor": lambda: (torch.rand(10, 3, 2, 1), None, True), } def forward(self, x: torch.Tensor, dim: int, keepdim: bool): diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 66f09ba89a9..2f625b955ce 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -13,14 +13,11 @@ from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -class NeedsRescaleOps(torch.nn.Module): +class MultipleOpsModel(torch.nn.Module): """A module containing ops that require INT32 inputs/outputs.""" input_t = Tuple[torch.Tensor, torch.Tensor] - def __init__(self): - super().__init__() - def forward(self, x, y): a = x * y b = torch.maximum(a, y) @@ -39,19 +36,41 @@ def get_inputs(self, dtype) -> input_t: else: raise ValueError("Not a valid input dtype for model") + def get_num_expected_rescales(self): + # "number of op nodes with i8 output" + "number of i8 node inputs" + return 3 + 7 -def test_insert_rescales(): - module = NeedsRescaleOps() - input_t = Tuple[torch.Tensor, torch.Tensor] + +class SumModel(torch.nn.Module): + input_t = Tuple[torch.Tensor] + + def forward(self, x): + a = torch.sum(x, 2, keepdim=True) # (1, 2, 1, 4) + b = torch.sum(a, [1, 3], keepdim=True) # (1, 1, 1, 1) + c = torch.sum(b, [0, 2], keepdim=False) # (1, 1) + return c + + def get_inputs(self, dtype) -> input_t: + if dtype == torch.float32: + return (torch.rand(1, 2, 3, 4),) + elif dtype == torch.int32: + return (torch.randint(0, 10, (1, 2, 3, 4), dtype=torch.int32),) + else: + raise ValueError("Not a valid input dtype for model") + + def get_num_expected_rescales(self): + # Two RESCALE nodes per SUM node + return 6 + + +def _test_model_with_f32_data(model): ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} ops_after = { - # "number of op nodes with i8 output" + "number of i8 node inputs" - "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3 - + 7, + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": model.get_num_expected_rescales(), } - pipeline = PassPipeline[input_t]( - module, - module.get_inputs(torch.float32), + pipeline = PassPipeline[model.input_t]( + model, + model.get_inputs(torch.float32), quantize=True, ops_not_before_pass=ops_not_before, ops_after_pass=ops_after, @@ -61,8 +80,16 @@ def test_insert_rescales(): pipeline.run() +def test_insert_rescales_sum_model(): + _test_model_with_f32_data(SumModel()) + + +def test_insert_rescales_multiple_ops_model(): + _test_model_with_f32_data(MultipleOpsModel()) + + def test_dont_insert_rescales(): - module = NeedsRescaleOps() + module = MultipleOpsModel() input_t = Tuple[torch.Tensor, torch.Tensor] ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} # All inputs are already i32. Rescales should not be added.