Merged
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -42,6 +42,7 @@
    QuantizeFullArgument,
    RetraceFoldedDtypesPass,
)
+from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
    FuseQuantizedActivationPass,
)
@@ -126,6 +127,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
        self.add_pass(ConvertMeanDimToAveragePoolPass())
        self.add_pass(DecomposeDivPass())
        self.add_pass(DecomposeSoftmaxesPass())
+        self.add_pass(FuseBatchnorm2DPass(exported_program))
Contributor: Only for MI?

Collaborator (Author): This is already done in the BI case, in prepare_pt2e I believe.
        self.add_pass(AnnotateDecomposedMatmulPass())
        self.add_pass(QuantizeFullArgument())
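To expand on the reply above: in the quantized (BI) flow, the conv -> batchnorm pair is expected to be folded during quantization preparation, so the new pass is only needed in the MI pipeline. A minimal sketch of that flow, assuming the ArmQuantizer import path and config helper (both may differ between versions); not part of this diff:

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e

# Assumed import path and helper name, based on the backend layout.
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.BatchNorm2d(3)).eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

quantizer = ArmQuantizer().set_global(get_symmetric_quantization_config())
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
print(prepared.graph)  # per the reply, the conv/bn pair should already be folded here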
128 changes: 128 additions & 0 deletions backends/arm/_passes/fuse_batchnorm2d_pass.py
@@ -0,0 +1,128 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import get_buffer, get_param
from torch.fx import Node
from torch.nn.utils.fusion import fuse_conv_bn_weights


class FuseBatchnorm2DPass(ExportPass):
    """Fuses the pattern convolution -> batchnorm by updating
    the weights and bias of the convolution and removing the batchnorm.
    """

    def __init__(self, exported_program: ExportedProgram):
        self.exported_program = exported_program
        super().__init__()

    def is_fuseable_conv_bn(self, node: Node) -> bool:
        """Returns True if node is a batchnorm that can be fused into
        a parent convolution."""
        if node.op != "call_function":
            return False
        if node.target not in (
            exir_ops.edge.aten._native_batch_norm_legit.default,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        ):
            return False
        conv = node.all_input_nodes[0]
        if conv.target != exir_ops.edge.aten.convolution.default:
            return False
        # Batchnorm users are getitem; we can only handle those that get the first element.
        for user in node.users:
            get_index = user.args[1]
            if get_index != 0:
                return False
        # Since we change the output of the conv, fuse only if it has a single user.
        if len(conv.users) > 1:
            return False
        # For similar reasons, only fuse if the conv parameters have a single user.
        if len(conv.all_input_nodes[1].users) > 1:
            return False
        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
            return False
        return True

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
        modified = False
        for node in graph_module.graph.nodes:
            if not self.is_fuseable_conv_bn(node):
                continue

            def get_param_or_none(arg) -> torch.nn.Parameter | None:
                """get_param, but check whether arg is None first."""
                return (
                    get_param(self.exported_program, arg) if arg is not None else None
                )

            # Get weight, bias, mean, var and epsilon from the batchnorm.
            bn = node
            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
            bn_weight = get_param_or_none(bn_weight_node)
            bn_bias = get_param_or_none(bn_bias_node)

            running_mean = get_buffer(self.exported_program, bn_mean_node)
            running_var = get_buffer(self.exported_program, bn_var_node)
            if running_mean is None or running_var is None:
                raise ValueError(
                    "Parameters running_mean and running_var of batchnorm can't be None."
                )
            epsilon = bn.args[-1]

            # Get weight and bias from the conv.
            conv_weight_node, conv_bias_node = conv.args[1:3]
            conv_weight = get_param(self.exported_program, conv_weight_node)
            conv_bias = get_param_or_none(conv_bias_node)
            if conv_weight is None:
                raise ValueError("Parameter weight of convolution can't be None.")

            # Compute the conv parameters folded with the batchnorm.
            fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
                conv_weight,
                conv_bias,
                running_mean,
                running_var,
                epsilon,
                bn_weight,
                bn_bias,
            )

            # Set the conv parameters to the fused values.
            def try_set_param(
                param_node: Node | None, param_value: torch.nn.Parameter
            ) -> bool:
                """set_param, but check whether param_node is None first.
                Returns True if the param was set successfully, otherwise False."""
                if param_node is not None:
                    param_name = (
                        self.exported_program.graph_signature.inputs_to_parameters[
                            param_node.name
                        ]
                    )
                    self.exported_program.state_dict[param_name] = param_value
                    return True
                return False

            try_set_param(conv_weight_node, fused_conv_weight)
            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
                bn_bias_node, fused_conv_bias
            ):
                # The conv didn't have a bias but the batchnorm did; steal the bias from the batchnorm.
                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
                conv.args = conv_args

            # Erasing the nodes is handled by dead-code elimination.
            for user in bn.users:
                user.replace_all_uses_with(conv)
            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module=graph_module, modified=modified)
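
As a quick numerical sanity check of that identity, the fused parameters can be compared against the original conv -> batchnorm pair directly (a standalone sketch built on torch.nn.utils.fusion.fuse_conv_bn_weights; not part of this diff):

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 3, 3).eval()
bn = torch.nn.BatchNorm2d(3).eval()
bn.running_mean = torch.rand(3)
bn.running_var = torch.rand(3) + 0.5  # keep the variances away from zero

fused = torch.nn.Conv2d(3, 3, 3).eval()
fused.weight, fused.bias = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)

x = torch.randn(1, 3, 8, 8)
# The fused conv must reproduce conv -> batchnorm up to float rounding.
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)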
31 changes: 21 additions & 10 deletions backends/arm/test/ops/test_conv_combos.py
@@ -16,6 +16,7 @@
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized
+from torch.nn.parameter import Parameter

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -112,12 +113,16 @@ class ComboConvBatchnormRelu6(torch.nn.Module):
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
]

def __init__(self):
def __init__(self, affine: bool):
super().__init__()
self.conv2d = torch.nn.Conv2d(
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
)
self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False)
self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
self.batch_norm2d.running_mean = torch.rand(3)
self.batch_norm2d.running_var = torch.rand(3)
self.batch_norm2d.weight = Parameter(torch.rand(3))
self.batch_norm2d.bias = Parameter(torch.rand(3))
self.relu6 = torch.nn.ReLU6()

    def get_inputs(self) -> Tuple[torch.Tensor]:
@@ -289,24 +294,30 @@ def test_conv_meandim_u85_BI(self):
    ##############################
    ## Conv + batch norm + relu ##
    ##############################
-    def test_conv_batchnorm_relu6_tosa_MI(self):
-        model = ComboConvBatchnormRelu6()
+    affine_params = [("affine", True), ("_no_affine", False)]
+
+    @parameterized.expand(affine_params)
+    def test_conv_batchnorm_relu6_tosa_MI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
        self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

-    def test_conv_batchnorm_relu6_tosa_BI(self):
-        model = ComboConvBatchnormRelu6()
+    @parameterized.expand(affine_params)
+    def test_conv_batchnorm_relu6_tosa_BI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

+    @parameterized.expand(affine_params)
    @pytest.mark.corstone_fvp
-    def test_conv_batchnorm_relu6_u55_BI(self):
-        model = ComboConvBatchnormRelu6()
+    def test_conv_batchnorm_relu6_u55_BI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
        self._test_conv_combo_ethos_BI_pipeline(
            model, common.get_u55_compile_spec(), model.get_inputs()
        )

+    @parameterized.expand(affine_params)
    @pytest.mark.corstone_fvp
-    def test_conv_batchnorm_relu_u85_BI(self):
-        model = ComboConvBatchnormRelu6()
+    def test_conv_batchnorm_relu_u85_BI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u85_compile_spec(),
44 changes: 44 additions & 0 deletions backends/arm/test/passes/test_cast_int64_pass.py
@@ -0,0 +1,44 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass

from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses


class Int64Model(torch.nn.Module):

Contributor: what makes it int64?

Collaborator (Author): A scalar always becomes int64.
    def forward(self, x: torch.Tensor):
        return x + 3

    def get_inputs(self):
        return (torch.rand(4),)
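
To illustrate the reply above: Python integer scalars default to int64 when they are materialized as tensors, which is how the constant 3 in Int64Model ends up as an int64 value in the lowered graph (a minimal check, not part of this diff):

import torch

# Python ints materialize as int64 tensors by default.
assert torch.tensor(3).dtype == torch.int64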


class TestCastInt64Pass(unittest.TestCase):

    def test_int64_model(self):
        module = Int64Model()
        test_pass_stage = RunPasses(passes_with_exported_program=[CastInt64ToInt32Pass])
        tester = (
            ArmTester(
                module,
                example_inputs=module.get_inputs(),
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
            )
            .export()
            .to_edge()
            .run_passes(test_pass_stage)
            .run_method_and_compare_outputs()
        )
        exported_program = tester.get_artifact("RunPasses").exported_program()
        for state in exported_program.state_dict:
            assert exported_program.state_dict[state].dtype != torch.int64