-
Notifications
You must be signed in to change notification settings - Fork 748
Add fuse batchnorm pass #8028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add fuse batchnorm pass #8028
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
zingo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import torch | ||
| from executorch.exir import ExportedProgram | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
| from executorch.exir.pass_base import ExportPass, PassResult | ||
| from torch._export.utils import get_buffer, get_param | ||
| from torch.fx import Node | ||
| from torch.nn.utils.fusion import fuse_conv_bn_weights | ||
|
|
||
|
|
||
| class FuseBatchnorm2DPass(ExportPass): | ||
| """Fuses the pattern convolution -> batchnorm by updating | ||
| the weights and bias of the convolution and removing the batchnorm. | ||
| """ | ||
|
|
||
| def __init__(self, exported_program: ExportedProgram): | ||
| self.exported_program = exported_program | ||
| super().__init__() | ||
|
|
||
| def is_fuseable_conv_bn(self, node: Node): | ||
| """Returns True if node is a batchnorm that can be fused into | ||
| a parent convolution.""" | ||
| if node.op != "call_function": | ||
| return False | ||
| if node.target not in ( | ||
| exir_ops.edge.aten._native_batch_norm_legit, | ||
| exir_ops.edge.aten._native_batch_norm_legit_no_training.default, | ||
| ): | ||
| return False | ||
| conv = node.all_input_nodes[0] | ||
| if conv.target != exir_ops.edge.aten.convolution.default: | ||
| return False | ||
| # Batchnorm users are getitem, we can only handle those that get first element. | ||
| for user in node.users: | ||
| get_index = user.args[1] | ||
| if get_index != 0: | ||
| return False | ||
| # Since we change the output of the conv, fuse only if it has single user. | ||
| if len(conv.users) > 1: | ||
| return False | ||
| # For similar reasons, only fuse if conv parameters have single user. | ||
| if len(conv.all_input_nodes[1].users) > 1: | ||
| return False | ||
| if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1: | ||
| return False | ||
| return True | ||
|
|
||
| def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 | ||
| modified = False | ||
| for node in graph_module.graph.nodes: | ||
| if not self.is_fuseable_conv_bn(node): | ||
| continue | ||
|
|
||
| def get_param_or_none(arg) -> torch.nn.Parameter | None: | ||
| """get_param but check if arg is none first.""" | ||
| return ( | ||
| get_param(self.exported_program, arg) if arg is not None else None | ||
| ) | ||
|
|
||
| # Get weight, bias, mean, var and epsilon from the batchnorm | ||
| bn = node | ||
| conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5] | ||
| bn_weight = get_param_or_none(bn_weight_node) | ||
| bn_bias = get_param_or_none(bn_bias_node) | ||
|
|
||
| running_mean = get_buffer(self.exported_program, bn_mean_node) | ||
| running_var = get_buffer(self.exported_program, bn_var_node) | ||
| if running_mean is None or running_var is None: | ||
| raise ValueError( | ||
| "Parameters running_mean and running_var of batchnorm can't be None." | ||
| ) | ||
| epsilon = bn.args[-1] | ||
|
|
||
| # Get weight and bias from conv | ||
| conv_weight_node, conv_bias_node = conv.args[1:3] | ||
| conv_weight = get_param(self.exported_program, conv_weight_node) | ||
| conv_bias = get_param_or_none(conv_bias_node) | ||
| if conv_weight is None: | ||
| raise ValueError("Parameter weight of convolution can't be None.") | ||
|
|
||
| # Compute conv parameters folded with batchnorm | ||
| fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights( | ||
| conv_weight, | ||
| conv_bias, | ||
| running_mean, | ||
| running_var, | ||
| epsilon, | ||
| bn_weight, | ||
| bn_bias, | ||
| ) | ||
|
|
||
| # Set the conv parameters to fused value | ||
| def try_set_param( | ||
| param_node: Node | None, param_value: torch.nn.Parameter | ||
| ) -> bool: | ||
| """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False.""" | ||
| if param_node is not None: | ||
| param_name = ( | ||
| self.exported_program.graph_signature.inputs_to_parameters[ | ||
| param_node.name | ||
| ] | ||
| ) | ||
| self.exported_program.state_dict[param_name] = param_value | ||
| return True | ||
| return False | ||
|
|
||
| try_set_param(conv_weight_node, fused_conv_weight) | ||
| if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param( | ||
| bn_bias_node, fused_conv_bias | ||
| ): | ||
| # Conv didn't have bias but batchnorm did, steal bias from batchnorm. | ||
| conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:]) | ||
| conv.args = conv_args | ||
|
|
||
| # Erasing nodes is handled by dead-code elimination. | ||
| for user in bn.users: | ||
| user.replace_all_uses_with(conv) | ||
| modified = True | ||
|
|
||
| if modified: | ||
| graph_module.graph.eliminate_dead_code() | ||
| graph_module.recompile() | ||
| graph_module = super().call(graph_module).graph_module | ||
| return PassResult(graph_module=graph_module, modified=modified) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass | ||
|
|
||
| from executorch.backends.arm.test import common | ||
|
|
||
| from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses | ||
|
|
||
|
|
||
| class Int64Model(torch.nn.Module): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what makes it int64?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A scalar always becomes int64 |
||
|
|
||
| def forward(self, x: torch.Tensor): | ||
| return x + 3 | ||
|
|
||
| def get_inputs(self): | ||
| return (torch.rand(4),) | ||
|
|
||
|
|
||
| class TestCastInt64Pass(unittest.TestCase): | ||
|
|
||
| def test_int64_model(self): | ||
| module = Int64Model() | ||
| test_pass_stage = RunPasses(passes_with_exported_program=[CastInt64ToInt32Pass]) | ||
| tester = ( | ||
| ArmTester( | ||
| module, | ||
| example_inputs=module.get_inputs(), | ||
| compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), | ||
| ) | ||
| .export() | ||
| .to_edge() | ||
| .run_passes(test_pass_stage) | ||
| .run_method_and_compare_outputs() | ||
| ) | ||
| exported_program = tester.get_artifact("RunPasses").exported_program() | ||
| for state in exported_program.state_dict: | ||
| assert exported_program.state_dict[state].dtype != torch.int64 | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only for MI?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is already done in the BI case, in
prepare_pt2eI believe