From 164a0d21e2c058250cc7cc7b24bacc1ceb97e1f5 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 20 Jan 2025 16:37:13 +0100 Subject: [PATCH 1/2] Support testing passes needing ExportedProgram ExportPasses that need to be initiated with an ExportedProgram can currently not be tested in a convenient way. This patch subclasses RunPasses, adding a parameter, `passes_with_exported_program`, that can be used just like `pass_list` but are initiated with an exported program before they are run. The functionality is tested in new tests for the CastInt64Pass and InsertTableOpsPass Signed-off-by: Erik Lundell Change-Id: I1712d86abe7cc3672c343db568df1264c0b9133e --- .../arm/test/passes/test_cast_int64_pass.py | 44 +++++++++++++++ .../test/passes/test_insert_table_ops_pass.py | 55 +++++++++++++++++++ backends/arm/test/tester/arm_tester.py | 50 ++++++++++++++++- 3 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 backends/arm/test/passes/test_cast_int64_pass.py create mode 100644 backends/arm/test/passes/test_insert_table_ops_pass.py diff --git a/backends/arm/test/passes/test_cast_int64_pass.py b/backends/arm/test/passes/test_cast_int64_pass.py new file mode 100644 index 00000000000..fdfab1f3af8 --- /dev/null +++ b/backends/arm/test/passes/test_cast_int64_pass.py @@ -0,0 +1,44 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses + + +class Int64Model(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return x + 3 + + def get_inputs(self): + return (torch.rand(4),) + + +class TestCastInt64Pass(unittest.TestCase): + + def test_int64_model(self): + module = Int64Model() + test_pass_stage = RunPasses(passes_with_exported_program=[CastInt64ToInt32Pass]) + tester = ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), + ) + .export() + .to_edge() + .run_passes(test_pass_stage) + .run_method_and_compare_outputs() + ) + exported_program = tester.get_artifact("RunPasses").exported_program() + for state in exported_program.state_dict: + assert exported_program.state_dict[state].dtype != torch.int64 diff --git a/backends/arm/test/passes/test_insert_table_ops_pass.py b/backends/arm/test/passes/test_insert_table_ops_pass.py new file mode 100644 index 00000000000..c0a9235fa6e --- /dev/null +++ b/backends/arm/test/passes/test_insert_table_ops_pass.py @@ -0,0 +1,55 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses + + +class Sigmoid(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return x.sigmoid() + + def get_inputs(self): + return (torch.rand(4),) + + +class TestInsertTablePass(unittest.TestCase): + + def test_insert_table_tosa_BI(self): + module = Sigmoid() + test_pass_stage = RunPasses( + [FoldAndAnnotateQParamsPass], + passes_with_exported_program=[InsertTableOpsPass], + ) + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), + ) + .quantize() + .export() + .to_edge() + .run_passes(test_pass_stage) + .check("tosa._table") + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + } + ) + .check_not(["aten_sigmoid_default"]) + ) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index a5e7267b372..8d9ac453eec 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -8,7 +8,7 @@ import os from collections import Counter from pprint import pformat -from typing import Iterable, List, Optional, Tuple, Union +from typing import Callable, Iterable, List, Optional, Tuple, Type, Union import executorch.backends.xnnpack.test.tester.tester as tester @@ -41,10 +41,18 @@ from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager +from executorch.exir import ( + 
EdgeCompileConfig, + EdgeProgramManager, + ExecutorchProgramManager, + ExportedProgram, +) +from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import Partitioner from executorch.exir.lowered_backend_module import LoweredBackendModule +from executorch.exir.pass_base import ExportPass +from executorch.exir.program._program import _update_exported_program_graph_module from tabulate import tabulate from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec @@ -158,6 +166,44 @@ def run_artifact(self, inputs): return super().run_artifact(inputs) +class RunPasses(tester.RunPasses): + + def __init__( + self, + pass_list: Optional[List[Type[ExportPass]]] = None, + pass_functions: Optional[List[Callable]] = None, + passes_with_exported_program: Optional[List[Type[ExportPass]]] = None, + ): + """Passes are run in the order they are passed: first pass_list, second pass_functions, + and lastly passes_with_exported_program.""" + self.pass_with_exported_program = passes_with_exported_program + super().__init__(pass_list, pass_functions) + + def run( + self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None + ) -> None: + if self.pass_with_exported_program is not None: + self.pass_functions = self.pass_functions or [] # type: ignore + + # pass_function list from superclass expects functions that take in + # and return ExportedPrograms. + # Create a wrapper to fit pass_with_exported_program into this. 
+ def wrap_ep_pass(ep_pass: Type[ExportPass]): + def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram: + pass_result = ep_pass(ep).call(ep.graph_module) + with validation_disabled(): + return _update_exported_program_graph_module( + ep, pass_result.graph_module + ) + + return wrapped_ep_pass + + self.pass_functions.extend( + [wrap_ep_pass(ep_pass) for ep_pass in self.pass_with_exported_program] + ) + super().run(artifact, inputs) + + class InitialModel(tester.Stage): def __init__(self, model: torch.nn.Module): self.model = model From 7ee9bc2201731b38429986a94781ba827f411c93 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 20 Jan 2025 10:35:34 +0100 Subject: [PATCH 2/2] Add pass for fusing batchnorm into conv The pass differs from existing fuse passes since they use the get_attr node which is not supported by ArmBackend. Instead, we update the existing parameters. Also adds tests. Change-Id: Iad6d70e632191d74d96df62b1837d37fe60e7d3a --- backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/fuse_batchnorm2d_pass.py | 128 ++++++++++++++ backends/arm/test/ops/test_conv_combos.py | 31 ++-- .../test/passes/test_fuse_batchnorm_pass.py | 158 ++++++++++++++++++ 4 files changed, 309 insertions(+), 10 deletions(-) create mode 100644 backends/arm/_passes/fuse_batchnorm2d_pass.py create mode 100644 backends/arm/test/passes/test_fuse_batchnorm_pass.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 2edbaafcb77..6beda533bf7 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -42,6 +42,7 @@ QuantizeFullArgument, RetraceFoldedDtypesPass, ) +from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found] FuseQuantizedActivationPass, ) @@ -126,6 +127,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) 
-> GraphModul self.add_pass(ConvertMeanDimToAveragePoolPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeSoftmaxesPass()) + self.add_pass(FuseBatchnorm2DPass(exported_program)) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeFullArgument()) diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py new file mode 100644 index 00000000000..6a5ece2e446 --- /dev/null +++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py @@ -0,0 +1,128 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._export.utils import get_buffer, get_param +from torch.fx import Node +from torch.nn.utils.fusion import fuse_conv_bn_weights + + +class FuseBatchnorm2DPass(ExportPass): + """Fuses the pattern convolution -> batchnorm by updating + the weights and bias of the convolution and removing the batchnorm. + """ + + def __init__(self, exported_program: ExportedProgram): + self.exported_program = exported_program + super().__init__() + + def is_fuseable_conv_bn(self, node: Node): + """Returns True if node is a batchnorm that can be fused into + a parent convolution.""" + if node.op != "call_function": + return False + if node.target not in ( + exir_ops.edge.aten._native_batch_norm_legit, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + ): + return False + conv = node.all_input_nodes[0] + if conv.target != exir_ops.edge.aten.convolution.default: + return False + # Batchnorm users are getitem, we can only handle those that get first element. 
+ for user in node.users: + get_index = user.args[1] + if get_index != 0: + return False + # Since we change the output of the conv, fuse only if it has single user. + if len(conv.users) > 1: + return False + # For similar reasons, only fuse if conv parameters have single user. + if len(conv.all_input_nodes[1].users) > 1: + return False + if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1: + return False + return True + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 + modified = False + for node in graph_module.graph.nodes: + if not self.is_fuseable_conv_bn(node): + continue + + def get_param_or_none(arg) -> torch.nn.Parameter | None: + """get_param but check if arg is none first.""" + return ( + get_param(self.exported_program, arg) if arg is not None else None + ) + + # Get weight, bias, mean, var and epsilon from the batchnorm + bn = node + conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5] + bn_weight = get_param_or_none(bn_weight_node) + bn_bias = get_param_or_none(bn_bias_node) + + running_mean = get_buffer(self.exported_program, bn_mean_node) + running_var = get_buffer(self.exported_program, bn_var_node) + if running_mean is None or running_var is None: + raise ValueError( + "Parameters running_mean and running_var of batchnorm can't be None." 
+ ) + epsilon = bn.args[-1] + + # Get weight and bias from conv + conv_weight_node, conv_bias_node = conv.args[1:3] + conv_weight = get_param(self.exported_program, conv_weight_node) + conv_bias = get_param_or_none(conv_bias_node) + if conv_weight is None: + raise ValueError("Parameter weight of convolution can't be None.") + + # Compute conv parameters folded with batchnorm + fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights( + conv_weight, + conv_bias, + running_mean, + running_var, + epsilon, + bn_weight, + bn_bias, + ) + + # Set the conv parameters to fused value + def try_set_param( + param_node: Node | None, param_value: torch.nn.Parameter + ) -> bool: + """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False.""" + if param_node is not None: + param_name = ( + self.exported_program.graph_signature.inputs_to_parameters[ + param_node.name + ] + ) + self.exported_program.state_dict[param_name] = param_value + return True + return False + + try_set_param(conv_weight_node, fused_conv_weight) + if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param( + bn_bias_node, fused_conv_bias + ): + # Conv didn't have bias but batchnorm did, steal bias from batchnorm. + conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:]) + conv.args = conv_args + + # Erasing nodes is handled by dead-code elimination. 
+ for user in bn.users: + user.replace_all_uses_with(conv) + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module=graph_module, modified=modified) diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 8352727a1c3..fe58d1265f4 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -16,6 +16,7 @@ from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.exir.backend.backend_details import CompileSpec from parameterized import parameterized +from torch.nn.parameter import Parameter logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -112,12 +113,16 @@ class ComboConvBatchnormRelu6(torch.nn.Module): "executorch_exir_dialects_edge__ops_aten_hardtanh_default", ] - def __init__(self): + def __init__(self, affine: bool): super().__init__() self.conv2d = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1 ) - self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False) + self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine) + self.batch_norm2d.running_mean = torch.rand(3) + self.batch_norm2d.running_var = torch.rand(3) + self.batch_norm2d.weight = Parameter(torch.rand(3)) + self.batch_norm2d.bias = Parameter(torch.rand(3)) self.relu6 = torch.nn.ReLU6() def get_inputs(self) -> Tuple[torch.Tensor]: @@ -289,24 +294,30 @@ def test_conv_meandim_u85_BI(self): ############################## ## Conv + batch norm + relu ## ############################## - def test_conv_batchnorm_relu6_tosa_MI(self): - model = ComboConvBatchnormRelu6() + affine_params = [("affine", True), ("_no_affine", False)] + + @parameterized.expand(affine_params) + def test_conv_batchnorm_relu6_tosa_MI(self, test_suffix, affine): + model = ComboConvBatchnormRelu6(affine) 
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs()) - def test_conv_batchnorm_relu6_tosa_BI(self): - model = ComboConvBatchnormRelu6() + @parameterized.expand(affine_params) + def test_conv_batchnorm_relu6_tosa_BI(self, test_suffix, affine): + model = ComboConvBatchnormRelu6(affine) self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs()) + @parameterized.expand(affine_params) @pytest.mark.corstone_fvp - def test_conv_batchnorm_relu6_u55_BI(self): - model = ComboConvBatchnormRelu6() + def test_conv_batchnorm_relu6_u55_BI(self, test_suffix, affine): + model = ComboConvBatchnormRelu6(affine) self._test_conv_combo_ethos_BI_pipeline( model, common.get_u55_compile_spec(), model.get_inputs() ) + @parameterized.expand(affine_params) @pytest.mark.corstone_fvp - def test_conv_batchnorm_relu_u85_BI(self): - model = ComboConvBatchnormRelu6() + def test_conv_batchnorm_relu_u85_BI(self, test_suffix, affine): + model = ComboConvBatchnormRelu6(affine) self._test_conv_combo_ethos_BI_pipeline( model, common.get_u85_compile_spec(), diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py new file mode 100644 index 00000000000..09f8f578fc2 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py @@ -0,0 +1,158 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import unittest + +import torch +from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses +from parameterized import parameterized + + +class MergeOneOfTwoBN(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + } + + def __init__(self, affine: bool): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1 + ) + self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine) + self.batch_norm2d.running_mean = torch.rand(3) + self.batch_norm2d.running_var = torch.rand(3) + if affine: + self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3)) + self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3)) + self.relu6 = torch.nn.ReLU6() + + def get_inputs(self) -> tuple[torch.Tensor]: + return (torch.randn(1, 3, 256, 256),) + + def forward(self, x): + x = self.conv2d(x) + x = self.batch_norm2d(x) + x = self.relu6(x) + x = self.batch_norm2d(x) + return x + + +class MergeTwosOfTwoBN(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, + } + + def __init__(self, affine: bool): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1 + 
) + self.conv2d2 = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1 + ) + self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine) + self.batch_norm2d.running_mean = torch.rand(3) + self.batch_norm2d.running_var = torch.rand(3) + if affine: + self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3)) + self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3)) + self.relu6 = torch.nn.ReLU6() + + def get_inputs(self) -> tuple[torch.Tensor]: + return (torch.randn(1, 3, 256, 256),) + + def forward(self, x): + x = self.conv2d(x) + x = self.batch_norm2d(x) + x = self.relu6(x) + x = self.conv2d2(x) + x = self.batch_norm2d(x) + return x + + +class MergeNoBN(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 3, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 3, + } + + def __init__(self, affine: bool): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1 + ) + self.conv2d2 = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1 + ) + self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine) + self.batch_norm2d.running_mean = torch.rand(3) + self.batch_norm2d.running_var = torch.rand(3) + if affine: + self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3)) + self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3)) + self.relu6 = torch.nn.ReLU6() + + def get_inputs(self) -> tuple[torch.Tensor]: + return (torch.randn(1, 3, 256, 256),) + + def forward(self, x): + x1 = self.conv2d(x) + x = self.batch_norm2d(x1) # Can't be fused since x1 has multiple users + x = self.relu6(x) + y = self.conv2d2(x1) + z = self.conv2d2(x) + a = self.batch_norm2d( + y + ) # Can't 
be fused since parameters of conv2d2 have multiple users. + + return z, a + + +modules = [ + MergeOneOfTwoBN(True), + MergeOneOfTwoBN(False), + MergeTwosOfTwoBN(True), + MergeNoBN(True), +] + + +class TestFuseBatchnormPass(unittest.TestCase): + + @parameterized.expand(modules) + def test_fuse_batchnorm_tosa_MI(self, module): + """Test various cases where the batchnorm should and shouldn't be fused.""" + inputs = module.get_inputs() + test_pass_stage = RunPasses(passes_with_exported_program=[FuseBatchnorm2DPass]) + ( + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + ) + .export() + .to_edge() + .check_count(module.ops_before_pass) + .run_passes(test_pass_stage) + .check_count(module.ops_after_pass) + .run_method_and_compare_outputs() + ) + )