From 794c155f611798ac20d8366c83ad55d95776338b Mon Sep 17 00:00:00 2001 From: Lukas Sztefek Date: Wed, 30 Apr 2025 10:24:52 +0200 Subject: [PATCH 1/2] Create NeutronAtenPassManager with initial BatchNorm fusing passes --- .../fuse_batch_norm_with_conv_pass.py | 146 ++++++++++++ .../fuse_batch_norm_with_linear_pass.py | 150 ++++++++++++ .../aten_passes/neutron_aten_pass_manager.py | 40 ++++ backends/nxp/tests/test_batch_norm_fusion.py | 215 ++++++++++++++++++ 4 files changed, 551 insertions(+) create mode 100644 backends/nxp/aten_passes/fuse_batch_norm_with_conv_pass.py create mode 100644 backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py create mode 100644 backends/nxp/aten_passes/neutron_aten_pass_manager.py create mode 100644 backends/nxp/tests/test_batch_norm_fusion.py diff --git a/backends/nxp/aten_passes/fuse_batch_norm_with_conv_pass.py b/backends/nxp/aten_passes/fuse_batch_norm_with_conv_pass.py new file mode 100644 index 00000000000..581a8163406 --- /dev/null +++ b/backends/nxp/aten_passes/fuse_batch_norm_with_conv_pass.py @@ -0,0 +1,146 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional + +import torch +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.nn.parameter import Parameter +from torch.nn.utils import fuse_conv_bn_weights + + +class FuseBatchNormWithConvPass(PassBase): + """The executorch batch normalization carries out the following computation [1]. + + (x - mean) / sqrt(var + eps) * W + B + + Which can be expressed as + + x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps))) + + So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static, + and the terms can be precomputed. If there is a `Conv` operator before the batch normalization, this scale and + bias can be statically integrated into the weights and bias of the `Conv`, which allows the batch norm to be + completely removed. + + + │ + ┌─────────────▼─────────────┐ + │ aten.conv1d | aten.conv2d │ + └─────────────┬─────────────┘ + │ │ + ┌─────────────────────▼─────────────────────┐ replace with ┌─────────────▼─────────────┐ + │ aten.batch_norm │ ──────────────► │ aten.conv1d | aten.conv2d │ + └─────────────────────┬─────────────────────┘ └─────────────┬─────────────┘ + │ ▼ + ▼ + + [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128 + """ + + def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None: + """Get the static data from a given node. If it doesn't have any data, return `None`.""" + if node is None or node.op != "get_attr": + return None + + target_atoms = node.target.split(".") + attr_itr = graph_module + for atom in target_atoms: + if not hasattr(attr_itr, atom): + return None + attr_itr = getattr(attr_itr, atom) + return attr_itr + + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + def _is_batch_norm(node_: Node) -> bool: + return ( + node_.op == "call_function" + and node_.target == torch.ops.aten.batch_norm.default + ) + + def _is_conv(node_: Node): + is_conv = node_.op == "call_function" and node_.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ) + has_single_user = len(node.users) == 1 + + return is_conv and has_single_user + + made_changes = False + + if not any(map(_is_batch_norm, graph_module.graph.nodes)): + return PassResult( + graph_module, made_changes + ) # No batch norm nodes in the model. + + for node in graph_module.graph.nodes: + if not _is_batch_norm(node): + continue # Not BatchNorm. + + bn_node = node + + if not _is_conv(bn_node.args[0]): + continue # Something other than a Conv node comes before the BatchNorm. + + conv_node = bn_node.args[0] + conv_weight_node = conv_node.args[1] + conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None + + # conv args: input, weight, bias, stride, padding, dilation, ... + conv_w = self._get_tensor_constant_from_node(graph_module, conv_weight_node) + conv_b = self._get_tensor_constant_from_node(graph_module, conv_bias_node) + + # batch norm args: input, weight, bias, running mean, training, running var, momentum, eps + bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1]) + bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2]) + bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3]) + bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4]) + bn_eps = bn_node.args[7] + + if any( + t is None for t in (conv_w, bn_rm, bn_rv) + ): # The other inputs can be None. + continue # The data is not static. Leave this BatchNorm as is (probably a rare case). + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b + ) + + # Update the weight and bias for Conv. + conv_args = list(conv_node.args) + if len(conv_args) == 2: + # Fill in the default bias argument. + conv_args.append(None) + + weight_attr_name = conv_weight_node.target + _assign_attr( + fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER + ) + + if conv_bias_node is not None: + bias_attr_name = conv_bias_node.target + _assign_attr( + fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER + ) + else: + # The Conv doesn't have a bias. Create a new one. + bias_attr_name = weight_attr_name + "_bias" + _assign_attr( + fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER + ) + with graph_module.graph.inserting_before(conv_node): + get_bias_node = graph_module.graph.get_attr(bias_attr_name) + + conv_args[2] = get_bias_node + + conv_node.args = tuple(conv_args) + + # Replace the uses of the BatchNorm with the Conv. + bn_node.replace_all_uses_with(conv_node) + + made_changes = True + + return PassResult(graph_module, made_changes) diff --git a/backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py b/backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py new file mode 100644 index 00000000000..b6ab4489bb8 --- /dev/null +++ b/backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py @@ -0,0 +1,150 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional + +import torch +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.nn.parameter import Parameter +from torch.nn.utils import fuse_linear_bn_weights + + +class FuseBatchNormWithLinearPass(PassBase): + """The executorch batch normalization carries out the following computation [1]. + + (x - mean) / sqrt(var + eps) * W + B + + Which can be expressed as + + x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps))) + + So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static, + and the terms can be precomputed. If there is a `Linear` operator before the batch normalization, this scale + and bias can be statically integrated into the weights and bias of the `Linear`, which allows the batch norm + to be completely removed. + + + │ + ┌──────▼──────┐ + │ aten.linear │ + └──────┬──────┘ + │ │ + ┌─────────────────────▼─────────────────────┐ replace with ┌──────▼──────┐ + │ aten.batch_norm │ ──────────────► │ aten.linear │ + └─────────────────────┬─────────────────────┘ └──────┬──────┘ + ▼ + + [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128 + """ + + def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None: + """Get the static data from a given node. If it doesn't have any data, return `None`.""" + if node is None or node.op != "get_attr": + return None + + target_atoms = node.target.split(".") + attr_itr = graph_module + for atom in target_atoms: + if not hasattr(attr_itr, atom): + return None + attr_itr = getattr(attr_itr, atom) + return attr_itr + + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + def _is_batch_norm(node_: Node) -> bool: + return ( + node_.op == "call_function" + and node_.target == torch.ops.aten.batch_norm.default + ) + + def _is_linear(node_: Node): + is_linear = ( + node_.op == "call_function" + and node_.target == torch.ops.aten.linear.default + ) + has_single_user = len(node.users) == 1 + + return is_linear and has_single_user + + made_changes = False + + if not any(map(_is_batch_norm, graph_module.graph.nodes)): + return PassResult( + graph_module, made_changes + ) # No batch norm nodes in the model. + + for node in graph_module.graph.nodes: + if not _is_batch_norm(node): + continue # Not BatchNorm. + + bn_node = node + + if not _is_linear(bn_node.args[0]): + continue # Something other than a Linear node comes before the BatchNorm. + + linear_node = bn_node.args[0] + linear_weight_node = linear_node.args[1] + linear_bias_node = ( + linear_node.args[2] if len(linear_node.args) > 2 else None + ) + + linear_w = self._get_tensor_constant_from_node( + graph_module, linear_weight_node + ) + linear_b = self._get_tensor_constant_from_node( + graph_module, linear_bias_node + ) + + # batch norm args: input, weight, bias, running mean, training, running var, momentum, eps + bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1]) + bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2]) + bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3]) + bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4]) + bn_eps = bn_node.args[7] + + if any( + t is None for t in (linear_w, bn_w, bn_b, bn_rm, bn_rv) + ): # The Linear bias can be None. + continue # The data is not static. Leave this BatchNorm as is (probably a rare case). + fused_weight, fused_bias = fuse_linear_bn_weights( + linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b + ) + + # Update the weight and bias for Linear. + linear_args = list(linear_node.args) + if len(linear_args) == 2: + # Fill in the default bias argument. + linear_args.append(None) + + weight_attr_name = linear_weight_node.target + _assign_attr( + fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER + ) + + if linear_bias_node is not None: + bias_attr_name = linear_bias_node.target + _assign_attr( + fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER + ) + else: + # The Linear doesn't have a bias. Create a new one. + bias_attr_name = weight_attr_name + "_bias" + _assign_attr( + fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER + ) + with graph_module.graph.inserting_before(linear_node): + get_bias_node = graph_module.graph.get_attr(bias_attr_name) + + linear_args[2] = get_bias_node + + linear_node.args = tuple(linear_args) + + # Replace the uses of the BatchNorm with the Linear. + bn_node.replace_all_uses_with(linear_node) + + made_changes = True + + return PassResult(graph_module, made_changes) diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py new file mode 100644 index 00000000000..6c989b6c136 --- /dev/null +++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py @@ -0,0 +1,40 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch + +from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import ( + FuseBatchNormWithConvPass, +) +from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( + FuseBatchNormWithLinearPass, +) +from executorch.exir.pass_manager import PassManager +from torch import nn +from torch.fx.passes.infra.pass_base import PassResult + +PassType = list[type[Callable[[torch.fx.GraphModule], PassResult]]] + + +class NeutronAtenPassManager(PassManager): + + def __init__(self, passes: list[PassType] = None): + passes: list[PassType] = passes or [ + FuseBatchNormWithConvPass(), + FuseBatchNormWithLinearPass(), + ] + + super().__init__(passes) + + def __call__(self, module: nn.Module) -> PassResult: + pass_result: PassResult = super().__call__(module) + + graph_module = pass_result.graph_module + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return pass_result diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py new file mode 100644 index 00000000000..c058543be2d --- /dev/null +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -0,0 +1,215 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy + +import numpy as np +import pytest +import torch +from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( + NeutronAtenPassManager, +) +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import ( + AddMMConverter, + MMConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import OverrideSupportedTargets +from torch import nn + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(42) + np.random.seed(23) + + +class BatchNormModule(torch.nn.Module): + def __init__(self, input_rank: int, num_features: int, eps: float = 1e-5): + super().__init__() + match input_rank - 2: + case 1: + self.batch_norm = nn.BatchNorm1d(num_features, eps) + case 2: + self.batch_norm = nn.BatchNorm2d(num_features, eps) + case 3: + self.batch_norm = nn.BatchNorm3d(num_features, eps) + case _: + raise ValueError + self.eval() + + def forward(self, x): + return self.batch_norm(x) + + +class ConvBatchNormModule(torch.nn.Module): + def __init__( + self, bias: bool, input_rank: int, num_features: int, eps: float = 1e-5 + ): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=num_features, + out_channels=num_features, + kernel_size=3, + bias=bias, + ) + self.batch_norm = BatchNormModule(input_rank, num_features, eps) + self.eval() + + def forward(self, x): + x = self.conv(x) + return self.batch_norm(x) + + +class LinearBatchNormModule(torch.nn.Module): + def __init__( + self, + bias: bool, + input_rank: int, + fc_in_features: int, + fc_out_features: int, + eps: float = 1e-5, + ): + super().__init__() + self.linear = torch.nn.Linear(fc_in_features, fc_out_features, bias=bias) + self.batch_norm = BatchNormModule(input_rank, fc_out_features, eps) + self.eval() + + def forward(self, x): + x = self.linear(x) + return self.batch_norm(x) + + +@pytest.mark.parametrize( + "bias", [True, False], ids=lambda x: "Bias" if x else "No bias" +) +@pytest.mark.parametrize( + "input_shape", [[4, 6, 8], [2, 4, 6, 8]], ids=lambda x: f"{len(x)}D" +) +def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]): + example_input = (torch.ones(*input_shape),) + + module = ConvBatchNormModule(bias, len(input_shape), 4) + program = torch.export.export_for_training(module, example_input, strict=True) + og_module = program.module() + + pm = NeutronAtenPassManager() + graph_module_out = pm(deepcopy(program.module())).graph_module + + # Make sure the fusion worked. + og_nodes = list(program.graph.nodes) + transformed_nodes = list(graph_module_out.graph.nodes) + + assert len(og_nodes) == (11 if bias else 10) + assert og_nodes[9 if bias else 8].target.__name__ == "batch_norm.default" + + assert len(transformed_nodes) == 5 + assert not any( + node.op == "call_function" and "batch_norm" in node.target.__name__ + for node in transformed_nodes + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = og_module(input_data).detach().numpy() + out2 = graph_module_out(input_data).detach().numpy() + assert np.allclose(out1, out2, atol=3.0e-7) + + +@pytest.mark.parametrize( + "bias", [True, False], ids=lambda x: "Bias" if x else "No bias" +) +def test_batch_norm_linear_fusing(bias: bool): + input_shape = (2, 4, 6, 8) + example_input = (torch.ones(*input_shape),) + + module = LinearBatchNormModule(bias, 4, input_shape[-1], input_shape[1]) + program = torch.export.export_for_training(module, example_input, strict=True) + og_module = program.module() + + pm = NeutronAtenPassManager() + graph_module_out = pm(deepcopy(program.module())).graph_module + + # Make sure the fusion worked. + og_nodes = list(og_module.graph.nodes) + transformed_nodes = list(graph_module_out.graph.nodes) + + assert len(og_nodes) == (11 if bias else 10) + assert og_nodes[8 if bias else 7].target.__name__ == "linear.default" + + assert len(transformed_nodes) == 5 + assert not any( + node.op == "call_function" and "batch_norm" in node.target.__name__ + for node in transformed_nodes + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = og_module(input_data).detach().numpy() + out2 = graph_module_out(input_data).detach().numpy() + assert np.allclose(out1, out2, atol=1.2e-7) + + +@pytest.mark.parametrize( + "bias", [True, False], ids=lambda x: "Bias" if x else "No bias" +) +def test_batch_norm_conv_fusing__full_pipeline__1d(bias: bool): + input_shape = [4, 6, 8] + module = ConvBatchNormModule(bias, len(input_shape), 4) + + edge_program = to_quantized_edge_program( + module, tuple(input_shape) + ).exported_program() + nodes = list(edge_program.graph.nodes) + + assert ( + len(nodes) == 13 + ) # 1D Conv currently isn't delegated, because it doesn't get quantized. + assert not any( + node.op == "call_function" and "batch_norm" in node.target.__name__ + for node in nodes + ) + + +@pytest.mark.parametrize( + "bias", [True, False], ids=lambda x: "Bias" if x else "No bias" +) +def test_batch_norm_conv_fusing__full_pipeline__2d(bias: bool): + input_shape = [1, 4, 6, 8] + module = ConvBatchNormModule(bias, len(input_shape), 4) + + edge_program = to_quantized_edge_program( + module, tuple(input_shape) + ).exported_program() + nodes = list(edge_program.graph.nodes) + + assert len(nodes) == 7 + assert not any( + node.op == "call_function" and "batch_norm" in node.target.__name__ + for node in nodes + ) + + +@pytest.mark.parametrize( + "bias", [True, False], ids=lambda x: "Bias" if x else "No bias" +) +def test_batch_norm_linear_fusing__full_pipeline(bias: bool): + input_shape = (2, 4, 6, 8) + module = LinearBatchNormModule(bias, 4, input_shape[-1], input_shape[1]) + + # Don't delegate the Linear node, because there seems to be a bug with the NeutronConverter/NeutronPartitioner. + # But that doesn't affect the validity of this test. + with OverrideSupportedTargets(AddMMConverter, new_targets=[]): + with OverrideSupportedTargets(MMConverter, new_targets=[]): + edge_program = to_quantized_edge_program( + module, tuple(input_shape) + ).exported_program() + nodes = list(edge_program.graph.nodes) + + assert len(nodes) == 14 + assert not any( + node.op == "call_function" and "batch_norm" in node.target.__name__ + for node in nodes + ) From 21ed973d4468fd5db5ca797f30031a4913e99cca Mon Sep 17 00:00:00 2001 From: Lukas Sztefek Date: Wed, 30 Apr 2025 10:25:31 +0200 Subject: [PATCH 2/2] Integrate NeutronAtenPassManager passes into pipeline --- backends/nxp/quantizer/neutron_quantizer.py | 6 +++++- backends/nxp/tests/executorch_pipeline.py | 13 ++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index eff7f513cb9..c757d3a84fa 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -7,6 +7,9 @@ from typing import List, Optional, Tuple, Union import torch +from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( + NeutronAtenPassManager, +) from executorch.backends.nxp.quantizer.patterns import ( AddmmPattern, @@ -202,4 +205,5 @@ def __init__(self): def transform_for_annotation( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: - return model + pass_runner = NeutronAtenPassManager() + return pass_runner(model).graph_module diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 6c452b99baf..d9ae9f427fd 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -8,9 +8,6 @@ from executorch import exir from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec - -# TODO (Robert Kalmar) Uncomment when NXP passes are ported to main -# from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass_manager import NXPPyTorchPassManager from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer from executorch.exir import ( EdgeCompileConfig, @@ -27,7 +24,7 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]): quantizer = NeutronQuantizer() m = prepare_pt2e(model, quantizer) - for _i, data in enumerate(calibration_inputs): + for data in calibration_inputs: m(*data) m = convert_pt2e(m) @@ -48,14 +45,8 @@ def to_quantized_edge_program( model, example_input, strict=True ) - # TODO(Robert Kalmar) uncoment when NXP passes are ported to main - # Run pre-processing passes of the float32 aten dialect program. - # pass_manager = NXPPyTorchPassManager(exir_program_aten) - # pass_manager.run() # All passes by default. - - exir_program_aten_module = exir_program_aten.module() exir_program_aten__module_quant = _quantize_model( - exir_program_aten_module, calibration_inputs + exir_program_aten.module(), calibration_inputs ) compile_spec = generate_neutron_compile_spec(