Commit fa97c1b

Add xnnpack pass to propagate custom meta field to q/dq nodes
1 parent: 7d8da19

File tree: 3 files changed, +153 −0 lines

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import PropagateCustomMetaPass
 from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
     RemoveRedundantCopyPass,
 )
@@ -59,6 +60,7 @@ def __init__(
             DimOrderOpsRevertPass,
             ConvertToUpsampleBilinear2d,
             ConvertToLinearPass,
+            PropagateCustomMetaPass,
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
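
For orientation, here is a hedged sketch of how this edited default pass list gets applied. It assumes the enclosing class in __init__.py is XNNPACKPassManager with a transform() method, as in the executorch source; neither appears in this diff.

from executorch.backends.xnnpack._passes import XNNPACKPassManager

def run_default_xnnpack_passes(exported_program):
    # Passing no explicit pass list uses the default list edited above, so
    # PropagateCustomMetaPass runs after ConvertToLinearPass and before
    # ConvertToSDPAPass.
    return XNNPACKPassManager(exported_program).transform()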
backends/xnnpack/_passes/propagate_custom_meta_pass.py (new file)

Lines changed: 44 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.pass_base import PassResult


class PropagateCustomMetaPass(XNNPACKPass):
    """
    Pass to propagate node.meta['custom'] from parent nodes to their q/dq
    child nodes.

    For all quantize/dequantize nodes in the graph, if the parent node has a
    node.meta['custom'] entry, this pass copies that value to the q/dq node's
    meta.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph

        for node in graph.nodes:
            if not (is_quant(node) or is_dequant(node)):
                continue

            # Get the parent node (first input argument).
            if len(node.all_input_nodes) == 0:
                continue

            parent_node = node.args[0]
            if not isinstance(parent_node, torch.fx.Node):
                continue

            if "custom" in parent_node.meta:
                node.meta["custom"] = parent_node.meta["custom"]

        graph_module.recompile()

        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
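
To see what the pass does in isolation, here is a hedged, self-contained sketch of the same propagation logic on a plain torch.fx graph. torch.quantize_per_tensor / torch.dequantize and the inline target check stand in for the backend's q/dq ops and its is_quant / is_dequant helpers, and the "custom" payload is an arbitrary illustrative value:

import torch
from torch.fx import symbolic_trace


class QDQ(torch.nn.Module):
    def forward(self, x):
        q = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)
        return torch.dequantize(q)


gm = symbolic_trace(QDQ())

# Tag the q node's parent (the input placeholder) with a custom meta entry,
# as an earlier pass such as the external-constants tagger would.
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["custom"] = {"delegate_constant_tag": "model"}

# Core of PropagateCustomMetaPass: copy meta["custom"] from the first input
# node onto each quantize/dequantize node. Because graph.nodes is in
# topological order, the tag also flows q -> dq in a single sweep.
for node in gm.graph.nodes:
    if node.target in (torch.quantize_per_tensor, torch.dequantize):
        parent = node.args[0]
        if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
            node.meta["custom"] = parent.meta["custom"]

# Both the q and dq nodes now carry {"delegate_constant_tag": "model"}.
for node in gm.graph.nodes:
    print(node.name, node.meta.get("custom"))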
Test for PropagateCustomMetaPass (new file)

Lines changed: 107 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass
from torch.export import export, ExportedProgram

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
    XnnpackPartitioner,
)
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)


class TestPropagateCustomMetaPass(unittest.TestCase):
    class ModuleLinear(torch.nn.Module):
        def __init__(
            self,
            in_size: int = 2,
            input_channels: int = 4,
            output_channels: int = 4,
            dtype: torch.dtype = torch.float,
            use_bias: bool = False,
        ):
            super().__init__()
            self.linear = torch.nn.Linear(
                input_channels, output_channels, bias=use_bias
            ).to(dtype=dtype)

            self.ic = input_channels
            self.oc = output_channels
            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
            self.op_dtype = dtype
            self.in_size = in_size

        def forward(self, x: torch.Tensor):
            return self.linear(x)

        def get_random_inputs(self):
            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
            return (inp,)

    def test_propagate_custom_meta_pass(self):
        eager_model = self.ModuleLinear(
            in_size=1,
            input_channels=32,
            output_channels=2,
        )
        test_inputs = eager_model.get_random_inputs()
        eager_result = eager_model(*test_inputs)

        # Quantize with the torchao quantize_ API.
        linear_config = Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
        )
        quantize_(eager_model, linear_config)
        quantized_result = eager_model(*test_inputs)
        unwrap_tensor_subclass(eager_model)

        # Tag the unlifted ep.module().
        tagged_module = export(
            eager_model, test_inputs, dynamic_shapes=None, strict=True
        ).module()
        delegate_external_constants_pass_unlifted(
            module=tagged_module,
            # gen_tag_fn gives the filename the weights will be saved to;
            # here the weights are saved as "model.ptd".
            gen_tag_fn=lambda x: "model",
        )

        ep = export(tagged_module, test_inputs, dynamic_shapes=None, strict=True)
        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
            per_op_mode=True,
        )
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
            generate_etrecord=False,
        )
        exec = edge.to_executorch(ExecutorchBackendConfig())

        program_buffer = exec.buffer
        data_buffer = bytes(exec._tensor_data.pop("model"))

        from executorch.extension.pybindings import portable_lib as runtime

        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
        output = module.forward(test_inputs)
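
A natural closing check, not part of the commit (hedged sketch; tolerances are illustrative, not from the source): the delegate's output should match the quantized eager reference computed above.

# Sketch: compare the runtime output against the quantized eager reference.
torch.testing.assert_close(output[0], quantized_result, atol=1e-2, rtol=1e-2)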
