Commit f7515a9

Add xnnpack pass to propagate custom meta field to q/dq nodes
1 parent 7d8da19 commit f7515a9

File tree

6 files changed: +280 -2 lines changed
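
Context for the pass: the test added in this commit tags the model's constants for external storage so the weights serialize to a separate "model.ptd" file. The tag is carried in node.meta["custom"], and quantization inserts q/dq nodes that would otherwise lack it; the new pass copies the tag onto them. A minimal sketch of the tagging step, using only calls that appear in the test file below:

    import torch
    from executorch.exir.passes.external_constants_pass import (
        delegate_external_constants_pass_unlifted,
    )
    from torch.export import export

    model = torch.nn.Linear(32, 2)
    inputs = (torch.randn(1, 32),)

    # Tag unlifted constants; gen_tag_fn returns the name of the external
    # data file, so the weights are saved as "model.ptd".
    tagged_module = export(model, inputs, strict=True).module()
    delegate_external_constants_pass_unlifted(
        module=tagged_module,
        gen_tag_fn=lambda x: "model",
    )
    exported_program = export(tagged_module, inputs, strict=True)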

backends/test/harness/stages/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .export import Export
 from .partition import Partition
-from .quantize import Quantize
+from .quantize import Quantize, Quantize_
 from .run_passes import RunPasses
 from .serialize import Serialize
 from .stage import Stage, StageType
@@ -12,6 +12,7 @@
     "Export",
     "Partition",
     "Quantize",
+    "Quantize_",
     "RunPasses",
     "Serialize",
     "Stage",

backends/test/harness/stages/quantize.py

Lines changed: 54 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple
 
 import torch
 
@@ -16,6 +16,14 @@
 )
 from torchao.quantization.pt2e.quantizer import Quantizer
 
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    quantize_,
+)
+
+from torchao.utils import unwrap_tensor_subclass
+
 
 class Quantize(Stage):
     def __init__(
@@ -79,3 +87,48 @@ def graph_module(self) -> str:
 
     def run_artifact(self, inputs):
         return self.converted_graph.forward(*inputs)
+
+
+class Quantize_(Stage):
+    """
+    TorchAO quantization stage using the quantize_ API.
+    """
+
+    def __init__(
+        self,
+        config: Any,
+        filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    ):
+        """
+        Args:
+            config: TorchAO quantization config (e.g., Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig)
+            filter_fn: Optional filter function to select which modules to quantize
+        """
+        self.config = config
+        self.filter_fn = filter_fn
+        self.quantized_module = None
+
+    def stage_type(self) -> StageType:
+        return StageType.QUANTIZE
+
+    def run(
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
+    ) -> None:
+        # Apply quantize_ to the model in place.
+        quantize_(artifact, self.config, self.filter_fn)
+
+        # Unwrap tensor subclasses for export compatibility.
+        unwrap_tensor_subclass(artifact)
+
+        self.quantized_module = artifact
+
+    @property
+    def artifact(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    @property
+    def graph_module(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    def run_artifact(self, inputs):
+        return self.quantized_module.forward(*inputs)
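
A minimal usage sketch for the new stage, mirroring the config that test_quantize_ in this commit passes in. Quantize_ quantizes the module in place with torchao's quantize_ and unwraps tensor subclasses so the result can be torch.export'ed:

    import torch
    from executorch.backends.test.harness.stages import Quantize_
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

    # Dynamic int8 activations with 4-bit, group-size-32 weights.
    config = Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    )

    stage = Quantize_(config=config)
    stage.run(torch.nn.Linear(32, 2), inputs=None)  # inputs are unused by this stage
    quantized_module = stage.artifact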

backends/test/harness/tester.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
     Export,
     Partition,
     Quantize,
+    Quantize_,
     RunPasses,
     Serialize,
     Stage,

backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
+    PropagateCustomMetaPass,
+)
 from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
     RemoveRedundantCopyPass,
 )
@@ -59,6 +62,7 @@ def __init__(
            DimOrderOpsRevertPass,
            ConvertToUpsampleBilinear2d,
            ConvertToLinearPass,
+           PropagateCustomMetaPass,
            ConvertToSDPAPass,
            ConstPropPass,
            FuseBatchNormPass,

backends/xnnpack/_passes/propagate_custom_meta_pass.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.pass_base import PassResult
+
+
+class PropagateCustomMetaPass(XNNPACKPass):
+    """
+    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
+    For all quantize/dequantize nodes in the graph, if the parent node has a
+    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+
+        for node in graph.nodes:
+            if not (is_quant(node) or is_dequant(node)):
+                continue
+
+            # Get the parent node (first input argument).
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            parent_node = node.args[0]
+            if not isinstance(parent_node, torch.fx.Node):
+                continue
+
+            if "custom" in parent_node.meta:
+                node.meta["custom"] = parent_node.meta["custom"]
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata.
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
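
To illustrate the propagation step in isolation, here is a self-contained sketch on a tiny FX graph. torch.quantize_per_tensor and torch.dequantize stand in for the q/dq ops the real pass matches via is_quant/is_dequant, and the tag payload is hypothetical:

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x):
            q = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)
            return torch.dequantize(q)

    gm = symbolic_trace(M())

    # Pretend an earlier pass tagged the quantize node's parent (the input).
    placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
    placeholder.meta["custom"] = {"tag": "model"}  # hypothetical payload

    # Core of PropagateCustomMetaPass: copy meta["custom"] from the first
    # input node onto every q/dq node. Iteration is in topological order,
    # so the dq node also inherits the tag its q parent just received.
    for node in gm.graph.nodes:
        if node.target not in (torch.quantize_per_tensor, torch.dequantize):
            continue
        parent = node.args[0]
        if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
            node.meta["custom"] = parent.meta["custom"]

    for node in gm.graph.nodes:
        if node.op == "call_function":
            assert node.meta["custom"] == {"tag": "model"}
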
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Callable, Optional, Tuple, Union
+
+import executorch.backends.test.harness.stages as BaseStages
+
+import torch
+from executorch.backends.test.harness.stages import StageType
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackFloatingPointPartitioner,
+    XnnpackPartitioner,
+)
+
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
+from executorch.backends.xnnpack.test.tester import (
+    Quantize as XNNPackQuantize,
+    RunPasses,
+    Tester,
+)
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge_transform_and_lower,
+)
+from executorch.exir.passes.external_constants_pass import (
+    delegate_external_constants_pass_unlifted,
+)
+from torch.export import export, ExportedProgram
+
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    quantize_,
+)
+from torchao.utils import unwrap_tensor_subclass
+
+try:
+    import executorch.extension.pybindings.portable_lib  # noqa[F401]
+    import executorch.kernels.quantized  # noqa[F401]
+
+    has_quantized_ops = True
+except:
+    has_quantized_ops = False
+    print("Missing quantized ops")
+
+
+class TestPropagateCustomMetaPass(unittest.TestCase):
+    class ModuleLinear(torch.nn.Module):
+        def __init__(
+            self,
+            in_size: int = 2,
+            input_channels: int = 4,
+            output_channels: int = 4,
+            dtype: torch.dtype = torch.float,
+            use_bias: bool = False,
+        ):
+            super().__init__()
+            self.linear = torch.nn.Linear(
+                input_channels, output_channels, bias=use_bias
+            ).to(dtype=dtype)
+
+            self.ic = input_channels
+            self.oc = output_channels
+            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
+            self.op_dtype = dtype
+            self.in_size = in_size
+
+        def forward(self, x: torch.Tensor):
+            return self.linear(x)
+
+        def get_random_inputs(self):
+            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
+            return (inp,)
+
+    class Export(BaseStages.Export):
+        def run(
+            self,
+            artifact: torch.nn.Module,
+            inputs: Tuple[torch.Tensor],
+        ) -> None:
+
+            tagged_module = export(
+                artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            ).module()
+            delegate_external_constants_pass_unlifted(
+                module=tagged_module,
+                gen_tag_fn=lambda x: "model",  # Filename the weights are saved to; here, "model.ptd".
+            )
+            self.exported_program = export(
+                tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            )
+
+    def _test_linear(
+        self,
+        partitioner: XnnpackPartitioner,
+        quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
+    ):
+        eager_model = self.ModuleLinear(
+            in_size=1,
+            input_channels=32,
+            output_channels=2,
+        )
+        test_inputs = eager_model.get_random_inputs()
+
+        tester = Tester(eager_model, test_inputs)
+        tester.quantize(quantization_stage)
+        tester.export(self.Export())
+        tester.to_edge_transform_and_lower(
+            ToEdgeTransformAndLower([partitioner])
+        ).to_executorch()
+        tester.run_method_and_compare_outputs()
+
+        exec = tester.get_artifact()
+        program_buffer = exec.buffer
+        self.assertEqual(len(exec._tensor_data), 1)
+        data_buffer = bytes(exec._tensor_data.pop("model"))
+        from executorch.extension.pybindings import portable_lib as runtime
+
+        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
+        output = module.forward(test_inputs)
+
+    def test_quantize_(self):
+        # Quantize with the torchao quantize_ API.
+        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+            per_op_mode=False,
+        )
+        linear_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(32),
+        )
+        self._test_linear(
+            DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
+        )
+
+    def test_pt2e_quantize(self):
+        # Quantize with pt2e quantization.
+        quant_configs = [
+            # per_channel
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
+            # per_tensor
+            get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
+            # per_channel_dynamic
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
+        ]
+        partitioners = []
+        for config_precision in [
+            ConfigPrecisionType.STATIC_QUANT,
+            ConfigPrecisionType.DYNAMIC_QUANT,
+        ]:
+            for per_op_mode in [True, False]:
+                partitioners.append(
+                    XnnpackPartitioner(
+                        config_precisions=config_precision,
+                        per_op_mode=per_op_mode,
+                    )
+                )
+        for quant_config in quant_configs:
+            for partitioner in partitioners:
+                self._test_linear(
+                    partitioner, XNNPackQuantize(quantization_config=quant_config)
+                )
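
For reference, the same export-and-lower flow without the Tester harness, using the public APIs the test imports (a sketch; the model and inputs are illustrative):

    import torch
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackPartitioner,
    )
    from executorch.exir import to_edge_transform_and_lower
    from torch.export import export

    model = torch.nn.Linear(32, 2).eval()
    inputs = (torch.randn(1, 32),)

    # Export to an ExportedProgram, then lower partitioned subgraphs to XNNPACK.
    ep = export(model, inputs, strict=True)
    edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
    executorch_program = edge.to_executorch()
    pte_buffer = executorch_program.buffer  # serialized .pte program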
