
Commit 54acf1b

[ARM backend] Update fuse_batchnorm_pass to create new placeholders

- This allows fusing bn+convs that have multiple users of the same weights.
- Adds new util functions create_constant_placeholder/delete_constant_placeholder to take care of updating the GraphSignature and state_dict/constants dict when handling constant placeholders.
- Adds and updates related tests.

Change-Id: I8e550614d9741de840786d9dca9f30af9eb95a64

1 parent 3c378dd commit 54acf1b
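
For context, the case this commit unblocks is a graph in which one weight placeholder feeds several convolutions. A minimal sketch of such a model (the module and names are illustrative, not taken from this patch or its tests):

import torch

class SharedWeightConvs(torch.nn.Module):
    # Two convs reuse one weight tensor, so after export the weight
    # placeholder has two users; updating it in place would change both.
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4, 3, 3))
        self.bn = torch.nn.BatchNorm2d(4)

    def forward(self, x):
        y = torch.nn.functional.conv2d(x, self.weight, padding=1)
        z = torch.nn.functional.conv2d(y, self.weight, padding=1)
        return self.bn(z)

The previous pass refused to fuse here because it overwrote the shared weight's state_dict entry, which would also alter the conv that keeps its original semantics. Creating a fresh placeholder for the fused parameters leaves the shared original intact.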

File tree

4 files changed (+305, -53 lines)

backends/arm/_passes/arm_pass_utils.py

Lines changed: 126 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -25,7 +25,13 @@
     is_param,
 )
 from torch._ops import OpOverload
-from torch._subclasses.fake_tensor import FakeTensor
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter
+from torch.export.graph_signature import (
+    ExportGraphSignature,
+    InputKind,
+    InputSpec,
+    TensorArgument,
+)


 def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -64,6 +70,124 @@ def get_param_tensor(
     raise RuntimeError(f"unsupported param type, {node.op}.")


+def create_constant_placeholder(
+    exp_program: ExportedProgram,
+    graph: torch.fx.Graph,
+    name: str,
+    kind: InputKind,
+    data: torch.Tensor,
+    persistent_buffer: Optional[bool] = None,
+) -> torch.fx.Node:
+    """
+    Creates and returns a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor.
+    graph.inserting_before/after() should be used before the call to decide where to insert the node.
+    """
+
+    target = name
+
+    # Add data to state_dict/constants
+    match kind:
+        case InputKind.PARAMETER:
+            exp_program.state_dict[target] = torch.nn.Parameter(
+                data, requires_grad=False
+            )
+        case InputKind.BUFFER:
+            if persistent_buffer is None:
+                raise RuntimeError(
+                    "Must set persistent_buffer when creating a new buffer."
+                )
+            elif persistent_buffer:
+                exp_program.state_dict[target] = data
+            else:
+                exp_program.constants[target] = data
+
+        case InputKind.CONSTANT_TENSOR:
+            exp_program.constants[target] = data
+        case _:
+            raise RuntimeError("Can only create constant input nodes.")
+
+    # Create node; use the same fake_tensor_mode as all other fake tensors in the graph
+    fake_tensor_mode = get_first_fake_tensor(list(graph.nodes)[0]).fake_mode
+    node = graph.create_node(op="placeholder", name=name, target=name)
+    node.meta["val"] = FakeTensorConverter().from_real_tensor(fake_tensor_mode, t=data)
+
+    # Add tensor to graph_signature in the same order as nodes in the graph
+    node_names = [n.name for n in graph.nodes if n.op == "placeholder"]
+    node_index = node_names.index(name)
+
+    input_specs = exp_program.graph_signature.input_specs
+    user_input_indices = [
+        i for i, spec in enumerate(input_specs) if spec.kind == InputKind.USER_INPUT
+    ]
+    if not all(
+        user_input_index > node_index for user_input_index in user_input_indices
+    ):
+        raise RuntimeError(
+            f"Failed to insert {name}; constant placeholder nodes must be inserted before user input nodes in the graph."
+        )
+
+    arg_spec = TensorArgument(name)
+    input_spec = InputSpec(kind, arg_spec, target, persistent_buffer)
+    input_specs.insert(node_index, input_spec)
+
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    return node
+
+
+def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Node):
+    """
+    Deletes a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor,
+    if the node does not have any users.
+    """
+    if not len(node.users) == 0:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it has users in the graph."
+        )
+
+    # Remove tensor from state_dict/constants
+    if node.name in exp_program.graph_signature.inputs_to_parameters:
+        target = exp_program.graph_signature.inputs_to_parameters[node.name]
+        del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_buffers:
+        target = exp_program.graph_signature.inputs_to_buffers[node.name]
+
+        if target in exp_program.graph_signature.non_persistent_buffers:
+            del exp_program.constants[target]
+        else:
+            del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_lifted_tensor_constants:
+        target = exp_program.graph_signature.inputs_to_lifted_tensor_constants[
+            node.name
+        ]
+        del exp_program.constants[target]
+    else:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it is not a parameter, a buffer, nor a lifted tensor constant."
+        )
+
+    # Remove input from graph signature
+    input_specs = [
+        spec
+        for spec in exp_program.graph_signature.input_specs
+        if spec.arg.name != node.name
+    ]
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    # Remove node from graph
+    node.graph.erase_node(node)
+
+
 def create_node(
     graph: torch.fx.Graph,
     op_target: OpOverload,
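
Taken together, a pass that wants to swap a constant for a transformed copy would use the two new helpers roughly as follows (a minimal sketch; exported_program, old_node, and new_data are assumed to be provided by the surrounding pass, and only the call pattern is taken from this patch):

from executorch.backends.arm._passes.arm_pass_utils import (
    create_constant_placeholder,
    delete_constant_placeholder,
)
from torch.export.graph_signature import InputKind

def swap_constant(exported_program, graph, old_node, new_data):
    # Insert before the old placeholder so the new node stays ahead of
    # all user inputs, which create_constant_placeholder enforces.
    with graph.inserting_before(old_node):
        new_node = create_constant_placeholder(
            exp_program=exported_program,
            graph=graph,
            kind=InputKind.PARAMETER,
            name=old_node.name + "_transformed",
            data=new_data,
        )
    # Redirect users, then drop the now-unused placeholder; this also
    # removes the state_dict/constants entry and fixes the graph signature.
    old_node.replace_all_uses_with(new_node)
    delete_constant_placeholder(exported_program, old_node)
    return new_node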

backends/arm/_passes/fuse_batchnorm2d_pass.py

Lines changed: 82 additions & 47 deletions
@@ -6,10 +6,15 @@
 # pyre-unsafe

 import torch
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import get_buffer, get_param
+from torch.export.graph_signature import InputKind
 from torch.fx import Node
 from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
         self.exported_program = exported_program
         super().__init__()

-    def is_fuseable_conv_bn(self, node: Node):
+    def is_fuseable_conv_bn(self, node: Node) -> bool:
         """Returns True if node is a batchnorm that can be fused into
         a parent convolution."""
         if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
         # Since we change the output of the conv, fuse only if it has single user.
         if len(conv.users) > 1:
             return False
-        # For similar reasons, only fuse if conv parameters have single user.
-        if len(conv.all_input_nodes[1].users) > 1:
-            return False
-        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
-            return False
         return True

+    def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
+        if conv_bias_node:
+            return conv_bias_node.name + "_fused_bn"
+        elif "weight" in conv_weight_node.name:
+            return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
+        else:
+            return conv_weight_node.name + "_bias_fused_bn"
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
         modified = False
+        constant_placeholders_to_delete = set()
         for node in graph_module.graph.nodes:
             if not self.is_fuseable_conv_bn(node):
                 continue
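
To make the naming scheme of get_bias_name concrete, a few illustrative placeholder names (hypothetical, not from the patch):

# conv has a bias node named "conv1_bias"     -> "conv1_bias_fused_bn"
# no bias, weight node named "conv1_weight"   -> "conv1_bias_fused_bn"
# no bias, weight node named "conv1_filter"   -> "conv1_filter_bias_fused_bn"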
@@ -64,67 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
                )

             # Get weight, bias, mean, var and epsilon from the batchnorm
-            bn = node
-            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
-            bn_weight = get_param_or_none(bn_weight_node)
-            bn_bias = get_param_or_none(bn_bias_node)
-
-            running_mean = get_buffer(self.exported_program, bn_mean_node)
-            running_var = get_buffer(self.exported_program, bn_var_node)
-            if running_mean is None or running_var is None:
+            bn_node = node
+            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
+                bn_node.args[0:5]
+            )
+            bn_weight_tensor = get_param_or_none(bn_weight_node)
+            bn_bias_tensor = get_param_or_none(bn_bias_node)
+            bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
+            bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
+            if bn_mean_tensor is None or bn_var_tensor is None:
                 raise ValueError(
                     "Parameters running_mean and running_var of batchnorm can't be None."
                 )
-            epsilon = bn.args[-1]
+            epsilon = bn_node.args[-1]

             # Get weight and bias from conv
             conv_weight_node, conv_bias_node = conv.args[1:3]
-            conv_weight = get_param(self.exported_program, conv_weight_node)
-            conv_bias = get_param_or_none(conv_bias_node)
-            if conv_weight is None:
+            conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
+            conv_bias_tensor = get_param_or_none(conv_bias_node)
+            if conv_weight_tensor is None:
                 raise ValueError("Parameter weight of convolution can't be None.")

             # Compute conv parameters folded with batchnorm
             fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
-                conv_weight,
-                conv_bias,
-                running_mean,
-                running_var,
+                conv_weight_tensor,
+                conv_bias_tensor,
+                bn_mean_tensor,
+                bn_var_tensor,
                 epsilon,
-                bn_weight,
-                bn_bias,
+                bn_weight_tensor,
+                bn_bias_tensor,
             )

-            # Set the conv parameters to fused value
-            def try_set_param(
-                param_node: Node | None, param_value: torch.nn.Parameter
-            ) -> bool:
-                """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
-                if param_node is not None:
-                    param_name = (
-                        self.exported_program.graph_signature.inputs_to_parameters[
-                            param_node.name
-                        ]
-                    )
-                    self.exported_program.state_dict[param_name] = param_value
-                    return True
-                return False
+            # Create fused weights and bias to conv and replace conv args
+            with graph_module.graph.inserting_before(conv_weight_node):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=conv_weight_node.name + "_fused_bn",
+                    data=fused_conv_weight,
+                )

-            try_set_param(conv_weight_node, fused_conv_weight)
-            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
-                bn_bias_node, fused_conv_bias
-            ):
-                # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
-                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
-                conv.args = conv_args
+                if fused_conv_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=self.get_bias_name(conv_weight_node, conv_bias_node),
+                        data=fused_conv_bias,
+                    )
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )

-            # Erasing nodes is handled by dead-code elimination.
-            for user in bn.users:
+            # Erasing batch-norm nodes is handled by dead-code elimination. After that we may remove their constant placeholder inputs.
+            for user in bn_node.users:
                 user.replace_all_uses_with(conv)
+
+            constant_placeholders_to_delete.update(
+                [
+                    bn_weight_node,
+                    bn_bias_node,
+                    bn_mean_node,
+                    bn_var_node,
+                    conv_weight_node,
+                    conv_bias_node,
+                ]
+            )
             modified = True

         if modified:
             graph_module.graph.eliminate_dead_code()
+            for constant_placeholder in constant_placeholders_to_delete:
+                if (constant_placeholder is not None) and (
+                    len(constant_placeholder.users) == 0
+                ):
+                    delete_constant_placeholder(
+                        self.exported_program, constant_placeholder
+                    )
+
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
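
The numerical core of the pass is torch.nn.utils.fusion.fuse_conv_bn_weights, which folds the batch-norm statistics and affine parameters into the convolution's weight and bias. A standalone sanity check of that equivalence (not part of the patch; shapes and values are arbitrary):

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)   # non-trivial stats so the check is meaningful
bn.running_var.uniform_(0.5, 2.0)

fused_weight, fused_bias = fuse_conv_bn_weights(
    conv.weight, conv.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

x = torch.randn(1, 3, 16, 16)
reference = bn(conv(x))
fused = torch.nn.functional.conv2d(x, fused_weight, fused_bias)
torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-5)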
