
Commit 5e68963

Author: morelos
Update base for Update on "[ET] correcting cpu ref quantize_per_channel logic to align with ATen"
# Context

The quantize_per_channel CPU reference implementation was not perfectly aligned with the ATen implementation, and produced errors when a non-default axis was specified. This bug went unnoticed because the test suite had only one test for the whole operator. To align more closely with ATen, this change replaces the old `apply_over_dim_list` approach with a single-loop implementation that computes the channel index directly.

# Changes

We change the core logic of quantize_per_channel to properly align with ATen's implementation, moving from the `apply_over_dim_list` approach to a single loop with direct channel index calculation. We also add more comprehensive tests for quantize_per_channel so that a bug like this isn't missed again.

Differential Revision: [D77746130](https://our.internmc.facebook.com/intern/diff/D77746130/)

[ghstack-poisoned]
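The single-loop, direct-channel-index scheme the message describes can be illustrated in Python. Below is a minimal sketch, not the actual ExecuTorch C++ kernel; the function name and signature are assumptions for illustration:

```python
# Minimal Python sketch of per-channel quantization using a single loop
# with direct channel-index calculation (illustrative only; the real
# kernel is C++).
import torch

def quantize_per_channel_ref(x, scales, zero_points, axis, qmin, qmax):
    # Elements that differ only in dimensions *after* `axis` share a channel,
    # so the channel index of flat element i is (i // inner) % size(axis).
    inner = 1
    for d in range(axis + 1, x.dim()):
        inner *= x.size(d)
    out = torch.empty_like(x, dtype=torch.int8)
    x_flat, out_flat = x.reshape(-1), out.reshape(-1)
    for i in range(x_flat.numel()):
        c = (i // inner) % x.size(axis)
        q = round(float(x_flat[i]) / float(scales[c])) + int(zero_points[c])
        out_flat[i] = max(qmin, min(qmax, q))  # clamp to the quantized range
    return out

# Cross-check against ATen for every axis. Ties at exact .5 could in
# principle round differently, but typical random inputs agree exactly.
x = torch.randn(2, 3, 4, 5)
for axis in range(x.dim()):
    n = x.size(axis)
    scales = torch.rand(n, dtype=torch.float64) * 0.1 + 0.01
    zero_points = torch.randint(-10, 10, (n,))
    ref = torch.quantize_per_channel(x, scales, zero_points, axis, torch.qint8)
    ours = quantize_per_channel_ref(x, scales, zero_points, axis, -128, 127)
    assert torch.equal(ref.int_repr(), ours)
```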
2 parents c37ce6c + 51ea3a6 commit 5e68963

File tree

31 files changed: +1695 −308 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ pip-out/
 *.model
 tokenizer.json
 *.pte
+*.ptd
 !test_bpe_tokenizer.bin
 !test_tiktoken_tokenizer.model

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def __init__(
         self,
         program: ExportedProgram,
         delegate_mapping_builder: DelegateMappingBuilder,
-        downcast_64_bit: bool = False,
+        downcast_64_bit: bool = True,
     ) -> None:
         self.program = program
         self.delegate_mapping_builder = delegate_mapping_builder
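For context on the flipped default: `downcast_64_bit` controls whether 64-bit tensors are narrowed to 32-bit during Vulkan serialization, presumably because many mobile GPU drivers lack 64-bit arithmetic. A hedged sketch of the policy; the helper below is illustrative and not the delegate's actual code:

```python
# Illustrative only: what a downcast_64_bit=True policy implies. The actual
# narrowing logic lives in the Vulkan serializer, not in this helper.
import torch

_DOWNCAST = {torch.int64: torch.int32, torch.float64: torch.float32}

def maybe_downcast(t: torch.Tensor, downcast_64_bit: bool = True) -> torch.Tensor:
    """Narrow 64-bit dtypes to their 32-bit counterparts when enabled."""
    if downcast_64_bit and t.dtype in _DOWNCAST:
        return t.to(_DOWNCAST[t.dtype])
    return t
```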

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 4 deletions

@@ -21,9 +21,7 @@
 )
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
-from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
-    FuseBatchNormWithConvPass,
-)
+from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

@@ -60,7 +58,7 @@ def __init__(
             ConvertToLinearPass,
             ConvertToSDPAPass,
             ConstPropPass,
-            FuseBatchNormWithConvPass,
+            FuseBatchNormPass,
             FuseActivationPass,
             DecomposeConcatenate,
             RemoveGetItemPass,
backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 247 additions & 0 deletions

@@ -0,0 +1,247 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch
from executorch.backends.transforms.utils import (
    create_constant_placeholder,
    delete_constant_placeholder,
)

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

from executorch.backends.xnnpack.utils.utils import (
    get_param_tensor,
    get_tensor_name,
    is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind

from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights


class FuseBatchNormPass(XNNPACKPass):
    """
    BatchNorm can be implemented using 1x1 Depthwise Convolution. However, doing so will increase
    memory usage since we serialize new weights to represent the convolution. In most cases,
    BatchNorm is used after convolution or linear. The 1x1 depthwise convolution can then be fused
    with the previous convolution. For linear cases, BatchNorm can be folded into the previous linear layer.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        constant_placeholders_to_delete = set()
        for input_node in graph.nodes:
            # We want to discover a chain of conv -> batch_norm or linear -> batch_norm.
            # Only proceed if the current node is a conv or linear, and has a single user/successor.
            is_conv = input_node.target == exir_ops.edge.aten.convolution.default
            is_linear = input_node.target == exir_ops.edge.aten.linear.default

            if not (is_conv or is_linear) or len(input_node.users) != 1:
                continue

            # The single user of the conv or linear node must be batch_norm. If not, bail.
            bn = list(input_node.users.keys())[0]
            if (
                bn.target != exir_ops.edge.aten.native_batch_norm.default
                and bn.target
                != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
            ):
                continue

            if not self.can_fuse(input_node, bn, self.exported_program):
                continue

            self._fuse_ops(
                graph_module,
                graph,
                input_node,
                bn,
                is_conv,
                constant_placeholders_to_delete,
            )

        if len(constant_placeholders_to_delete) > 0:
            graph_module.graph.eliminate_dead_code()
            for node in constant_placeholders_to_delete:
                if (node is not None) and (len(node.users) == 0):
                    delete_constant_placeholder(self.exported_program, node)

        graph_module.recompile()
        # To regenerate metadata and shape information, retrace module.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

    @staticmethod
    def can_fuse(
        input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
    ) -> bool:
        """
        Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
        """

        # All users of the batch_norm node must be getitem ops.
        # batch_norm returns a 3-element tuple.
        # Each user must only access the first element of the tuple.
        if [
            (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
        ].count(False):
            return False

        input_node_weights = input_node.args[1]
        bn_weights = bn.args[1]

        # Check that the weights for conv or linear and batch_norm are both params.
        if not isinstance(input_node_weights, torch.fx.Node) or not isinstance(
            bn_weights, torch.fx.Node
        ):
            return False

        if [
            is_param_node(program, node) for node in {input_node_weights, bn_weights}
        ].count(False):
            return False

        return True

    def _fuse_ops(
        self,
        graph_module: torch.fx.GraphModule,
        graph: torch.fx.Graph,
        input_node: torch.fx.Node,
        bn: torch.fx.Node,
        is_conv: bool,
        constant_placeholders_to_delete: set,
    ) -> None:
        """
        Fuse a BatchNorm node into the preceding convolution or linear node.
        Update the fused node's weight and bias, rewire users of the BatchNorm output,
        and remove the BatchNorm node.
        """

        if is_conv:
            assert len(input_node.args) == 9
            has_bias_arg = True
        else:
            # Otherwise, this is a linear node.
            # Linear has 2 or 3 args depending on whether bias is used: (input, weight, bias).
            assert len(input_node.args) in (2, 3)
            has_bias_arg = len(input_node.args) == 3

        # Get the weight and bias parameters from the conv or linear op.
        input_node_weight = get_param_tensor(self.exported_program, input_node.args[1])
        input_node_weight_name = get_tensor_name(
            self.exported_program, input_node.args[1]
        )
        assert input_node_weight is not None

        if has_bias_arg:
            input_node_bias = get_param_tensor(
                self.exported_program, input_node.args[2]
            )
            input_node_bias_name = get_tensor_name(
                self.exported_program, input_node.args[2]
            )
        else:
            input_node_bias = None
            input_node_bias_name = ""

        # Get the parameters from the batch_norm op.
        assert (
            bn.target == exir_ops.edge.aten.native_batch_norm.default
            and len(bn.args) == 8
        ) or (
            bn.target == exir_ops.edge.aten._native_batch_norm_legit_no_training.default
            and len(bn.args) == 7
        )
        bn_weight = get_param_tensor(self.exported_program, bn.args[1])
        bn_bias = get_param_tensor(self.exported_program, bn.args[2])

        running_mean = get_param_tensor(self.exported_program, bn.args[3])
        assert running_mean is not None

        running_var = get_param_tensor(self.exported_program, bn.args[4])
        assert running_var is not None

        # args[7] for native_batch_norm, but args[6] for
        # _native_batch_norm_legit_no_training (which doesn't have training
        # as an arg).
        eps = bn.args[-1]

        # Compute the updated weight and bias after fusing the conv or linear op with the batch_norm op.
        fuse_args = (
            input_node_weight,
            input_node_bias,
            running_mean,
            running_var,
            eps,
            bn_weight,
            bn_bias,
        )

        if is_conv:
            is_transpose = input_node.args[6]
            fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose)
        else:
            # Otherwise, this is a linear node.
            fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args)

        fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_")
        if input_node_bias_name == "":
            fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace(
                ".", "_"
            )
        else:
            fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_")

        # Modify the graph by updating the weight and bias of the conv or linear op
        # with the fused weight and bias params, and replacing all the users
        # of getitem(batch_norm) with the conv or linear op.
        with graph.inserting_before(input_node.args[1]):
            fused_op_weight_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=graph_module.graph,
                kind=InputKind.PARAMETER,
                name=fused_weight_name,
                data=fused_weight,
            )
            if fused_bias is not None:
                fused_op_bias_node = create_constant_placeholder(
                    exp_program=self.exported_program,
                    graph=graph_module.graph,
                    kind=InputKind.PARAMETER,
                    name=fused_bias_name,
                    data=fused_bias,
                )
            else:
                fused_op_bias_node = None

        # Replace the original weight and bias with the fused batch_norm values.
        args = list(input_node.args)
        args[1] = fused_op_weight_node

        if has_bias_arg:
            # Overwrite original bias with the fused bias.
            args[2] = fused_op_bias_node
        elif fused_op_bias_node is not None:
            # Add the fused bias as a new argument if no bias had originally existed in the input_node.
            args.append(fused_op_bias_node)

        input_node.args = tuple(args)

        # Remove any use of batch_norm from the graph.
        for user in bn.users.copy():
            assert user.target == operator.getitem
            user.replace_all_uses_with(input_node)
            graph.erase_node(user)

        graph.erase_node(bn)
        constant_placeholders_to_delete.update(input_node.args[1:3] + bn.args[1:5])
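The weight folding this pass performs is the standard batch-norm fusion: W' = W * gamma / sqrt(running_var + eps) and b' = (b - running_mean) * gamma / sqrt(running_var + eps) + beta. A minimal eager-mode sketch, assuming a plain Conv2d/BatchNorm2d pair, using the same `fuse_conv_bn_weights` helper the pass calls above (the pass itself operates on exported graph nodes, not nn.Modules):

```python
# Eager-mode illustration of the conv + batch_norm weight folding.
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
bn = torch.nn.BatchNorm2d(8).eval()

# Fold the batch-norm statistics and affine parameters into conv's weight/bias.
fused_w, fused_b = fuse_conv_bn_weights(
    conv.weight, conv.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    expected = bn(conv(x))
    fused = torch.nn.functional.conv2d(x, fused_w, fused_b)
assert torch.allclose(expected, fused, atol=1e-5)
```

Folding into the preceding op this way avoids serializing a separate 1x1 depthwise convolution for the BatchNorm, which is the memory concern the class docstring describes.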
