
Commit bc8f182

Add xnnpack pass to propagate custom meta field to q/dq nodes
1 parent 7d8da19 commit bc8f182

File tree

3 files changed: +160 −0 lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
+    PropagateCustomMetaPass,
+)
 from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
     RemoveRedundantCopyPass,
 )
@@ -59,6 +62,7 @@ def __init__(
             DimOrderOpsRevertPass,
             ConvertToUpsampleBilinear2d,
             ConvertToLinearPass,
+            PropagateCustomMetaPass,
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
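With this registration, PropagateCustomMetaPass runs early in the default XNNPACK pass pipeline, after ConvertToLinearPass and before ConvertToSDPAPass. A minimal sketch of invoking the pass on its own, assuming XNNPACKPass follows the usual ExecuTorch convention of taking the ExportedProgram in its constructor and being called on the program's GraphModule (`model` and `example_inputs` here are hypothetical):

import torch

from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
    PropagateCustomMetaPass,
)

# Hypothetical model and inputs; any exported program containing q/dq
# nodes would exercise the pass.
ep = torch.export.export(model, example_inputs)

# Assumption: the pass is constructed with the ExportedProgram and is
# callable on its GraphModule, returning a PassResult.
result = PropagateCustomMetaPass(ep)(ep.graph_module)
graph_module = result.graph_module  # q/dq nodes now mirror their parents' "custom" meta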
backends/xnnpack/_passes/propagate_custom_meta_pass.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.pass_base import PassResult


class PropagateCustomMetaPass(XNNPACKPass):
    """
    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.

    For all quantize/dequantize nodes in the graph, if the parent node has a
    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph

        for node in graph.nodes:
            if not (is_quant(node) or is_dequant(node)):
                continue

            # Get the parent node (first input argument)
            if len(node.all_input_nodes) == 0:
                continue

            parent_node = node.args[0]
            if not isinstance(parent_node, torch.fx.Node):
                continue

            if "custom" in parent_node.meta:
                node.meta["custom"] = parent_node.meta["custom"]

        graph_module.recompile()

        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
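The core transformation is easy to see on a toy FX graph. The self-contained sketch below uses operator.add as a stand-in for a quantize/dequantize target (the real pass matches nodes via is_quant/is_dequant) and applies the same parent-to-child meta copy; the "custom" payload is a made-up example:

import operator

import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
# Tag the parent node; this payload is hypothetical.
x.meta["custom"] = {"delegate_constant_tag": "model"}

# Stand-in for a q/dq node whose first input is `x`.
q = graph.call_function(operator.add, (x, x))

# The same propagation step the pass performs:
for node in graph.nodes:
    parent = node.args[0] if node.args else None
    if isinstance(parent, fx.Node) and "custom" in parent.meta:
        node.meta["custom"] = parent.meta["custom"]

assert q.meta["custom"] == {"delegate_constant_tag": "model"}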
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
    XnnpackPartitioner,
)
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)
from torch.export import export, ExportedProgram

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass


class TestPropagateCustomMetaPass(unittest.TestCase):
    class ModuleLinear(torch.nn.Module):
        def __init__(
            self,
            in_size: int = 2,
            input_channels: int = 4,
            output_channels: int = 4,
            dtype: torch.dtype = torch.float,
            use_bias: bool = False,
        ):
            super().__init__()
            self.linear = torch.nn.Linear(
                input_channels, output_channels, bias=use_bias
            ).to(dtype=dtype)

            self.ic = input_channels
            self.oc = output_channels
            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
            self.op_dtype = dtype
            self.in_size = in_size

        def forward(self, x: torch.Tensor):
            return self.linear(x)

        def get_random_inputs(self):
            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
            return (inp,)

    def test_propagate_custom_meta_pass(self):
        eager_model = self.ModuleLinear(
            in_size=1,
            input_channels=32,
            output_channels=2,
        )
        test_inputs = eager_model.get_random_inputs()
        eager_result = eager_model(*test_inputs)

        # Quantize with torchao quantize_ API.
        linear_config = Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
        )
        quantize_(eager_model, linear_config)
        quantized_result = eager_model(*test_inputs)
        unwrap_tensor_subclass(eager_model)

        # Tag the unlifted ep.module().
        tagged_module = export(
            eager_model, test_inputs, dynamic_shapes=None, strict=True
        ).module()
        delegate_external_constants_pass_unlifted(
            module=tagged_module,
            gen_tag_fn=lambda x: "model",  # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
        )

        ep = export(tagged_module, test_inputs, dynamic_shapes=None, strict=True)
        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
            per_op_mode=True,
        )
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
            generate_etrecord=False,
        )
        exec = edge.to_executorch(ExecutorchBackendConfig())

        program_buffer = exec.buffer
        data_buffer = bytes(exec._tensor_data.pop("model"))

        from executorch.extension.pybindings import portable_lib as runtime

        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
        output = module.forward(test_inputs)
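The test quantizes a small linear module with torchao, tags its weights for external ".ptd" storage via delegate_external_constants_pass_unlifted (which is expected to record a node.meta["custom"] entry on the tagged nodes), lowers through the XNNPACK partitioner, and runs the resulting program from in-memory buffers. A hypothetical spot-check of the invariant the pass establishes, not part of the commit, could run the pass directly on the exported program and compare tags:

from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
    PropagateCustomMetaPass,
)
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant

# Assumption: the pass is callable on the ExportedProgram's GraphModule.
gm = PropagateCustomMetaPass(ep)(ep.graph_module).graph_module
for node in gm.graph.nodes:
    if is_quant(node) or is_dequant(node):
        parent = node.args[0]
        if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
            assert node.meta.get("custom") == parent.meta["custom"]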
