Commit db612db

Add xnnpack pass to propagate custom meta field to q/dq nodes
1 parent d00279d commit db612db

5 files changed: +259 -2 lines changed

backends/test/harness/stages/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .export import Export
 from .partition import Partition
-from .quantize import Quantize
+from .quantize import Quantize, Quantize_
 from .run_passes import RunPasses
 from .serialize import Serialize
 from .stage import Stage, StageType
@@ -12,6 +12,7 @@
     "Export",
     "Partition",
     "Quantize",
+    "Quantize_",
     "RunPasses",
     "Serialize",
     "Stage",

backends/test/harness/stages/quantize.py

Lines changed: 48 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple

 import torch

@@ -15,6 +15,8 @@
     prepare_qat_pt2e,
 )
 from torchao.quantization.pt2e.quantizer import Quantizer
+from torchao.quantization.quant_api import quantize_
+from torchao.utils import unwrap_tensor_subclass


 class Quantize(Stage):
@@ -79,3 +81,48 @@ def graph_module(self) -> str:

     def run_artifact(self, inputs):
         return self.converted_graph.forward(*inputs)
+
+
+class Quantize_(Stage):
+    """
+    TorchAO quantization stage using the quantize_ API.
+    """
+
+    def __init__(
+        self,
+        config: Any,
+        filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    ):
+        """
+        Args:
+            config: TorchAO quantization config (e.g., Int4WeightOnlyConfig,
+                Int8DynamicActivationInt8WeightConfig)
+            filter_fn: Optional filter function to select which modules to quantize
+        """
+        self.config = config
+        self.filter_fn = filter_fn
+        self.quantized_module = None
+
+    def stage_type(self) -> str:
+        return StageType.QUANTIZE
+
+    def run(
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
+    ) -> None:
+        # Apply quantize_ to the model in place.
+        quantize_(artifact, self.config, self.filter_fn)
+
+        # Unwrap tensor subclasses for export compatibility.
+        unwrap_tensor_subclass(artifact)
+
+        self.quantized_module = artifact
+
+    @property
+    def artifact(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    @property
+    def graph_module(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    def run_artifact(self, inputs):
+        return self.quantized_module.forward(*inputs)
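As a usage sketch (not part of the commit), the new stage can also be driven directly, outside the tester harness. The toy model below is hypothetical; Int8DynamicActivationInt8WeightConfig is one of the torchao configs the docstring names:

import torch
from executorch.backends.test.harness.stages import Quantize_
from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

# Hypothetical toy model; any eager module with linear layers would do.
model = torch.nn.Sequential(
    torch.nn.Linear(32, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 2),
)
example_inputs = (torch.randn(1, 32),)

stage = Quantize_(config=Int8DynamicActivationInt8WeightConfig())
# run() mutates the module in place via quantize_, then unwraps tensor
# subclasses so the result can be torch.export'ed.
stage.run(model, example_inputs)
print(stage.run_artifact(example_inputs))  # forward pass through the quantized module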

backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
+    PropagateCustomMetaPass,
+)
 from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
     RemoveRedundantCopyPass,
 )
@@ -59,6 +62,7 @@ def __init__(
             DimOrderOpsRevertPass,
             ConvertToUpsampleBilinear2d,
             ConvertToLinearPass,
+            PropagateCustomMetaPass,
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
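The new pass is registered in the default XNNPACK pass pipeline, right after ConvertToLinearPass. As a minimal standalone sketch (assuming XNNPACKPass subclasses are constructed with the exported program, as the pass manager does; the toy module is hypothetical):

import torch
from executorch.backends.xnnpack._passes import PropagateCustomMetaPass

class TinyLinear(torch.nn.Module):  # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(TinyLinear(), (torch.randn(1, 4),))

# Calling the pass instance dispatches to its call() method and returns a
# PassResult carrying the updated graph module. On an unquantized graph the
# pass is a no-op, since there are no q/dq nodes to tag.
result = PropagateCustomMetaPass(ep)(ep.graph_module)
graph_module = result.graph_module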
backends/xnnpack/_passes/propagate_custom_meta_pass.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.pass_base import PassResult
+
+
+class PropagateCustomMetaPass(XNNPACKPass):
+    """
+    Pass to propagate node.meta['custom'] from parent nodes to their q/dq
+    child nodes. For every quantize/dequantize node in the graph, if the
+    parent node has a node.meta['custom'] entry, this pass copies that
+    value into the q/dq node's meta.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+
+        for node in graph.nodes:
+            if not (is_quant(node) or is_dequant(node)):
+                continue
+
+            # Get the parent node (first input argument).
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            parent_node = node.args[0]
+            if not isinstance(parent_node, torch.fx.Node):
+                continue
+
+            if "custom" in parent_node.meta:
+                node.meta["custom"] = parent_node.meta["custom"]
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata.
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
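To make the propagation rule concrete, here is an illustrative, self-contained sketch (not from the commit) using plain torch.fx. The relu node stands in for a q/dq node, and the tag payload is a hypothetical example of what an upstream tagging pass might have written:

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = fx.symbolic_trace(Toy())
nodes = list(gm.graph.nodes)  # [placeholder x, add, relu, output]
parent, child = nodes[1], nodes[2]

# An upstream pass marks the parent node; the payload here is hypothetical.
parent.meta["custom"] = {"tag": "model"}

# The core rule the pass applies to every q/dq node: copy the parent's
# "custom" meta entry onto the child.
if "custom" in parent.meta:
    child.meta["custom"] = parent.meta["custom"]

assert child.meta["custom"] == parent.meta["custom"]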
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple, Union
+
+import executorch.backends.test.harness.stages as BaseStages
+
+import torch
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
+from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+from executorch.exir.passes.external_constants_pass import (
+    delegate_external_constants_pass_unlifted,
+)
+
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
+
+try:
+    import executorch.extension.pybindings.portable_lib  # noqa[F401]
+    import executorch.kernels.quantized  # noqa[F401]
+
+    has_quantized_ops = True
+except:
+    has_quantized_ops = False
+    print("Missing quantized ops")
+
+
+class TestPropagateCustomMetaPass(unittest.TestCase):
+    class ModuleLinear(torch.nn.Module):
+        def __init__(
+            self,
+            in_size: int = 2,
+            input_channels: int = 4,
+            output_channels: int = 4,
+            dtype: torch.dtype = torch.float,
+            use_bias: bool = False,
+        ):
+            super().__init__()
+            self.linear = torch.nn.Linear(
+                input_channels, output_channels, bias=use_bias
+            ).to(dtype=dtype)
+
+            self.ic = input_channels
+            self.oc = output_channels
+            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
+            self.op_dtype = dtype
+            self.in_size = in_size
+
+        def forward(self, x: torch.Tensor):
+            return self.linear(x)
+
+        def get_random_inputs(self):
+            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
+            return (inp,)
+
+    class Export(BaseStages.Export):
+        def run(
+            self,
+            artifact: torch.nn.Module,
+            inputs: Tuple[torch.Tensor],
+        ) -> None:
+            tagged_module = torch.export.export(
+                artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            ).module()
+            # gen_tag_fn returns the filename the weights will be saved to.
+            # In this case, weights will be saved as "model.ptd".
+            delegate_external_constants_pass_unlifted(
+                module=tagged_module,
+                gen_tag_fn=lambda x: "model",
+            )
+            self.exported_program = torch.export.export(
+                tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            )
+
+    def _test_linear(
+        self,
+        partitioner: XnnpackPartitioner,
+        quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
+    ):
+        eager_model = self.ModuleLinear(
+            in_size=1,
+            input_channels=32,
+            output_channels=2,
+        )
+        test_inputs = eager_model.get_random_inputs()
+
+        tester = Tester(eager_model, test_inputs)
+        tester.quantize(quantization_stage)
+        tester.export(self.Export())
+        tester.to_edge_transform_and_lower(
+            ToEdgeTransformAndLower([partitioner])
+        ).to_executorch()
+        tester.run_method_and_compare_outputs()
+
+        exec = tester.get_artifact()
+        program_buffer = exec.buffer
+        self.assertEqual(len(exec._tensor_data), 1)
+        data_buffer = bytes(exec._tensor_data.pop("model"))
+        self.assertTrue(len(data_buffer) > 200)
+        from executorch.extension.pybindings import portable_lib as runtime
+
+        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
+        output = module.forward(test_inputs)
+        reference_output = exec.exported_program().module()(
+            test_inputs[0],
+        )
+        self.assertTrue(torch.allclose(output[0], reference_output, 1e-2))
+
+        # with self.assertRaises(RuntimeError):
+        #     runtime._load_for_executorch_from_buffer(program_buffer).forward(
+        #         test_inputs
+        #     )
+
+    def test_quantize_(self):
+        # Quantize with the torchao quantize_ API.
+        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+            per_op_mode=False,
+        )
+        linear_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(32),
+        )
+        self._test_linear(
+            DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
+        )
+
+    def test_pt2e_quantize(self):
+        # Quantize with pt2e quantize.
+        quant_configs = [
+            # per_tensor
+            get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
+            # per_channel
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
+            # per_channel_dynamic
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
+        ]
+        for quant_config in quant_configs:
+            precision = (
+                ConfigPrecisionType.DYNAMIC_QUANT
+                if quant_config.input_activation.is_dynamic
+                else ConfigPrecisionType.STATIC_QUANT
+            )
+            for per_op_mode in [True, False]:
+                partitioner = XnnpackPartitioner(
+                    config_precisions=precision, per_op_mode=per_op_mode
+                )
+                self._test_linear(
+                    partitioner, XNNPackQuantize(quantization_config=quant_config)
+                )