diff --git a/backends/test/harness/stages/__init__.py b/backends/test/harness/stages/__init__.py
index 36ed435ebd7..14431191621 100644
--- a/backends/test/harness/stages/__init__.py
+++ b/backends/test/harness/stages/__init__.py
@@ -1,6 +1,6 @@
 from .export import Export
 from .partition import Partition
-from .quantize import Quantize
+from .quantize import Quantize, Quantize_
 from .run_passes import RunPasses
 from .serialize import Serialize
 from .stage import Stage, StageType
@@ -12,6 +12,7 @@
     "Export",
     "Partition",
     "Quantize",
+    "Quantize_",
     "RunPasses",
     "Serialize",
     "Stage",
diff --git a/backends/test/harness/stages/quantize.py b/backends/test/harness/stages/quantize.py
index 9edb600e19f..6c6036c8104 100644
--- a/backends/test/harness/stages/quantize.py
+++ b/backends/test/harness/stages/quantize.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple
 
 import torch
 
@@ -15,6 +15,8 @@
     prepare_qat_pt2e,
 )
 from torchao.quantization.pt2e.quantizer import Quantizer
+from torchao.quantization.quant_api import quantize_
+from torchao.utils import unwrap_tensor_subclass
 
 
 class Quantize(Stage):
@@ -79,3 +81,48 @@ def graph_module(self) -> str:
 
     def run_artifact(self, inputs):
         return self.converted_graph.forward(*inputs)
+
+
+class Quantize_(Stage):
+    """
+    TorchAO quantization stage using the quantize_ API.
+    """
+
+    def __init__(
+        self,
+        config: Any,
+        filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    ):
+        """
+        Args:
+            config: TorchAO quantization config (e.g., Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig)
+            filter_fn: Optional filter function to select which modules to quantize
+        """
+        self.config = config
+        self.filter_fn = filter_fn
+        self.quantized_module = None
+
+    def stage_type(self) -> StageType:
+        return StageType.QUANTIZE
+
+    def run(
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
+    ) -> None:
+        # Apply quantize_ to the model
+        quantize_(artifact, self.config, self.filter_fn)
+
+        # Unwrap tensor subclasses for export compatibility
+        unwrap_tensor_subclass(artifact)
+
+        self.quantized_module = artifact
+
+    @property
+    def artifact(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    @property
+    def graph_module(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    def run_artifact(self, inputs):
+        return self.quantized_module.forward(*inputs)
diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py
index 141718bde6f..c48896b3d81 100644
--- a/backends/xnnpack/_passes/__init__.py
+++ b/backends/xnnpack/_passes/__init__.py
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
+    PropagateCustomMetaPass,
+)
 from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
     RemoveRedundantCopyPass,
 )
@@ -59,6 +62,7 @@ def __init__(
                 DimOrderOpsRevertPass,
                 ConvertToUpsampleBilinear2d,
                 ConvertToLinearPass,
+                PropagateCustomMetaPass,
                 ConvertToSDPAPass,
                 ConstPropPass,
                 FuseBatchNormPass,
diff --git a/backends/xnnpack/_passes/propagate_custom_meta_pass.py b/backends/xnnpack/_passes/propagate_custom_meta_pass.py
new file mode 100644
index 00000000000..b1a03514446
--- /dev/null
+++ b/backends/xnnpack/_passes/propagate_custom_meta_pass.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.pass_base import PassResult
+
+
+class PropagateCustomMetaPass(XNNPACKPass):
+    """
+    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
+    For all quantize/dequantize nodes in the graph, if the parent node has a
+    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+
+        for node in graph.nodes:
+            if not (is_quant(node) or is_dequant(node)):
+                continue
+
+            # Get the parent node (first input argument)
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            parent_node = node.args[0]
+            if not isinstance(parent_node, torch.fx.Node):
+                continue
+
+            if "custom" in parent_node.meta:
+                node.meta["custom"] = parent_node.meta["custom"]
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
diff --git a/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py b/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py
new file mode 100644
index 00000000000..f76ee3a0868
--- /dev/null
+++ b/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple, Union
+
+import executorch.backends.test.harness.stages as BaseStages
+
+import torch
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
+from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+from executorch.exir.passes.external_constants_pass import (
+    delegate_external_constants_pass_unlifted,
+)
+
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
+
+try:
+    import executorch.extension.pybindings.portable_lib  # noqa: F401
+    import executorch.kernels.quantized  # noqa: F401
+
+    has_quantized_ops = True
+except ImportError:
+    has_quantized_ops = False
+    print("Missing quantized ops")
+
+
+class TestPropagateCustomMetaPass(unittest.TestCase):
+    class ModuleLinear(torch.nn.Module):
+        def __init__(
+            self,
+            in_size: int = 2,
+            input_channels: int = 4,
+            output_channels: int = 4,
+            dtype: torch.dtype = torch.float,
+            use_bias: bool = False,
+        ):
+            super().__init__()
+            self.linear = torch.nn.Linear(
+                input_channels, output_channels, bias=use_bias
+            ).to(dtype=dtype)
+
+            self.ic = input_channels
+            self.oc = output_channels
+            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
+            self.op_dtype = dtype
+            self.in_size = in_size
+
+        def forward(self, x: torch.Tensor):
+            return self.linear(x)
+
+        def get_random_inputs(self):
+            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
+            return (inp,)
+
+    class Export(BaseStages.Export):
+        def run(
+            self,
+            artifact: torch.nn.Module,
+            inputs: Tuple[torch.Tensor],
+        ) -> None:
+
+            tagged_module = torch.export.export(
+                artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            ).module()
+            delegate_external_constants_pass_unlifted(
+                module=tagged_module,
+                gen_tag_fn=lambda x: "model",  # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
+            )
+            self.exported_program = torch.export.export(
+                tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            )
+
+    def _test_linear(
+        self,
+        partitioner: XnnpackPartitioner,
+        quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
+    ):
+        eager_model = self.ModuleLinear(
+            in_size=1,
+            input_channels=32,
+            output_channels=2,
+        )
+        test_inputs = eager_model.get_random_inputs()
+
+        tester = Tester(eager_model, test_inputs)
+        tester.quantize(quantization_stage)
+        tester.export(self.Export())
+        tester.to_edge_transform_and_lower(
+            ToEdgeTransformAndLower([partitioner])
+        ).to_executorch()
+        tester.run_method_and_compare_outputs()
+
+        exec = tester.get_artifact()
+        program_buffer = exec.buffer
+        self.assertEqual(len(exec._tensor_data), 1)
+        data_buffer = bytes(exec._tensor_data.pop("model"))
+        self.assertTrue(len(data_buffer) > 200)
+        from executorch.extension.pybindings import portable_lib as runtime
+
+        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
+        output = module.forward(test_inputs)
+        reference_output = exec.exported_program().module()(
+            test_inputs[0],
+        )
+        self.assertTrue(torch.allclose(output[0], reference_output, 1e-2))
+
+        # with self.assertRaises(RuntimeError):
+        #     runtime._load_for_executorch_from_buffer(program_buffer).forward(
+        #         test_inputs
+        #     )
+
+    def test_quantize_(self):
+        # Quantize with torchao quantize_ API.
+        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+            per_op_mode=False,
+        )
+        linear_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(32),
+        )
+        self._test_linear(
+            DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
+        )
+
+    def test_pt2e_quantize(self):
+        # Quantize with pt2e quantize.
+        quant_configs = [
+            # per_tensor
+            get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
+            # per_channel
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
+            # per_channel_dynamic
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
+        ]
+        for quant_config in quant_configs:
+            precision = (
+                ConfigPrecisionType.DYNAMIC_QUANT
+                if quant_config.input_activation.is_dynamic
+                else ConfigPrecisionType.STATIC_QUANT
+            )
+            for per_op_mode in [True, False]:
+                partitioner = XnnpackPartitioner(
+                    config_precisions=precision, per_op_mode=per_op_mode
+                )
+                self._test_linear(
+                    partitioner, XNNPackQuantize(quantization_config=quant_config)
+                )
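
The new Quantize_ stage is a thin wrapper around torchao's quantize_ API. A minimal sketch of the flow it performs, outside the test harness, is below; the toy Linear model and its example inputs are illustrative only and not part of this change:

import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(32, 2, bias=False))
example_inputs = (torch.randn(1, 32),)

config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)

quantize_(model, config)  # swaps Linear weights for quantized tensor subclasses, in place
unwrap_tensor_subclass(model)  # flattens the subclasses so torch.export can trace the module
exported = torch.export.export(model, example_inputs, strict=True)

Inside the harness, the first two calls happen in Quantize_.run, and the Export stage (or the test's Export override above) performs the torch.export step.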
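
PropagateCustomMetaPass exists so that tags written into node.meta["custom"] (for example by delegate_external_constants_pass_unlifted) also land on the quantize/dequantize nodes inserted between a tagged node and its consumers. The following is a generic torch.fx sketch of that copy step only; it does not use the actual pass, the q/dq checks, or the XNNPACKPass plumbing, and the tag payload is purely illustrative:

import operator

import torch


class M(torch.nn.Module):
    def forward(self, x):
        y = x * 2.0  # stand-in for a tagged parent node
        return y + 1.0  # stand-in for a q/dq-style consumer


gm = torch.fx.symbolic_trace(M())

# Tag one node; the real external-constants pass stores its own keys here.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.mul:
        node.meta["custom"] = {"tag": "model"}

# Copy the tag from a node's first input onto the node itself, mirroring what
# PropagateCustomMetaPass does, but restricted there to quantize/dequantize nodes.
for node in gm.graph.nodes:
    parent = node.args[0] if node.args else None
    if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
        node.meta.setdefault("custom", parent.meta["custom"])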