diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py
index 141718bde6f..88f503d9ea0 100644
--- a/backends/xnnpack/_passes/__init__.py
+++ b/backends/xnnpack/_passes/__init__.py
@@ -19,6 +19,7 @@
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
 )
+from executorch.backends.xnnpack._passes.external_constants_pass import PropagateCustomMetaPass
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
@@ -59,6 +60,9 @@ def __init__(
             DimOrderOpsRevertPass,
             ConvertToUpsampleBilinear2d,
             ConvertToLinearPass,
+            # Add pass here.
+            # Find q/dq nodes; if their inputs are tagged, tag the nodes.
+            PropagateCustomMetaPass,
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
@@ -92,4 +96,5 @@ def transform(self) -> ExportedProgram:
                     f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}"
                 )
             ep = _transform(ep, transform_pass)
+        breakpoint()
         return ep
diff --git a/backends/xnnpack/_passes/external_constants_pass.py b/backends/xnnpack/_passes/external_constants_pass.py
new file mode 100644
index 00000000000..6d69f8c7fe3
--- /dev/null
+++ b/backends/xnnpack/_passes/external_constants_pass.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.pass_base import PassResult
+
+
+class PropagateCustomMetaPass(XNNPACKPass):
+    """
+    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
+
+    For all quantize/dequantize nodes in the graph, if the parent (input) node has a
+    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
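+
+    The intent is to keep q/dq nodes that wrap tagged constants consistent with
+    their parents, so that consumers of node.meta['custom'] (for example, the
+    external-constant tagging used for program-data separation) see the same tag
+    on the q/dq nodes as on the constants they consume.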
+ """ + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + + for node in graph.nodes: + # Check if this is a quantize or dequantize node + if not (is_quant(node) or is_dequant(node)): + continue + + # Get the parent node (first input argument) + if len(node.all_input_nodes) == 0: + continue + + parent_node = node.args[0] + if not isinstance(parent_node, torch.fx.Node): + continue + + # Check if parent has 'custom' metadata + if "custom" in parent_node.meta: + breakpoint() + # Copy the custom metadata to the q/dq node + node.meta["custom"] = parent_node.meta["custom"] + + graph_module.recompile() + + # Call parent's call method to retrace and regenerate metadata + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 68226644859..8f4991de659 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -296,6 +296,7 @@ def get_quant_params( offset=UINT64_MAX, size=num_bytes, named_key=scale_name ) ) + breakpoint() self._named_data_store.add_named_data( scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT ) @@ -630,6 +631,7 @@ def get_serialized_buffer_index( logging.info( f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" ) + breakpoint() self._named_data_store.add_named_data( named_key, bytes(array), diff --git a/backends/xnnpack/operators/op_linear.py b/backends/xnnpack/operators/op_linear.py index dda1d3e53ef..e1e63014d41 100644 --- a/backends/xnnpack/operators/op_linear.py +++ b/backends/xnnpack/operators/op_linear.py @@ -51,6 +51,7 @@ def define_node( # filter weight_node = get_input_node(node, 1) + breakpoint() weight_quant_params = QuantParams.from_weights( weight_node, self._exported_program ) diff --git a/exir/program/_program.py b/exir/program/_program.py index 9298eb3e88d..9d484cc5ef4 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1372,6 +1372,7 @@ def to_edge_transform_and_lower( # noqa: C901 for name, partitioner_list in partitioner.items(): if i < len(partitioner_list): method_to_partitioner[name] = partitioner_list[i] + breakpoint() edge_manager = edge_manager.to_backend(method_to_partitioner) for name, program in edge_manager._edge_programs.items(): diff --git a/export_quant_xnnpack.py b/export_quant_xnnpack.py new file mode 100644 index 00000000000..983d7e850cd --- /dev/null +++ b/export_quant_xnnpack.py @@ -0,0 +1,165 @@ + +import torch + +from torchao.quantization.granularity import PerGroup, PerAxis +from torchao.quantization.quant_api import ( + IntxWeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + quantize_, +) +from torchao.utils import unwrap_tensor_subclass +from torch.export import export, ExportedProgram +from executorch.exir import ( + EdgeProgramManager, + ExecutorchBackendConfig, + ExecutorchProgramManager, +) +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, +) +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackFloatingPointPartitioner, + XnnpackPartitioner, +) +from executorch.exir import ( + EdgeCompileConfig, + EdgeProgramManager, + to_edge_transform_and_lower, +) + +# Note: I think this works fine. 
+# Note: I think this works fine.
+# Quantize embeddings with 8 bits, per channel
+# embedding_config = IntxWeightOnlyConfig(
+#     weight_dtype=torch.int8,
+#     granularity=PerAxis(0),
+# )
+# quantize_(
+#     eager_model,
+#     embedding_config,
+#     lambda m, fqn: isinstance(m, torch.nn.Embedding),
+# )
+
+torch.manual_seed(0)
+
+class ModuleLinear(torch.nn.Module):
+    def __init__(
+        self,
+        in_size: int = 2,
+        input_channels: int = 4,
+        output_channels: int = 4,
+        dtype: torch.dtype = torch.float,
+        use_bias: bool = False
+    ):
+        super().__init__()
+        self.linear = torch.nn.Linear(
+            input_channels, output_channels, bias=use_bias
+        ).to(dtype=dtype)
+
+        self.ic = input_channels
+        self.oc = output_channels
+        assert dtype in [torch.float, torch.half], "Unsupported op dtype"
+        self.op_dtype = dtype
+        self.in_size = in_size
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+    def get_random_inputs(self):
+        inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
+        return (inp,)
+
+### EAGER.
+eager_model = ModuleLinear(
+    in_size=1,
+    input_channels=32,
+    output_channels=2,
+)
+
+test_inputs = eager_model.get_random_inputs()
+eager_result = eager_model(*test_inputs)
+print("eager result: ", eager_result)
+
+### QUANTIZE.
+# Quantize linear layers with 8-bit dynamic activations and 4-bit weights
+linear_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(32),
+)
+# NOTE: comment this out, and program-data separation works well.
+quantize_(eager_model, linear_config)
+
+quantized_result = eager_model(*test_inputs)
+print("quantized results: ", quantized_result)
+print(torch.allclose(eager_result, quantized_result, atol=1e-1))
+
+unwrap_tensor_subclass(eager_model)
+unwrapped_result = eager_model(*test_inputs)
+print("unwrapped results: ", unwrapped_result)
+print(torch.allclose(quantized_result, unwrapped_result, atol=1e-3))
+
+from executorch.exir.passes.external_constants_pass import (
+    delegate_external_constants_pass_unlifted,
+)
+### EXPORT AND TAG WEIGHTS.
+ep1 = export(eager_model, test_inputs, dynamic_shapes=None, strict=True)
+exported_result = ep1.module()(*test_inputs)
+print("exported program: ", exported_result)
+print(torch.allclose(quantized_result, exported_result, atol=1e-3))
+print("Graph: ")
+ep1.graph_module.print_readable()
+# Tag the unlifted ep.module().
+tagged_module = ep1.module()
+delegate_external_constants_pass_unlifted(
+    module=tagged_module,
+    gen_tag_fn=lambda x: "model",  # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
+)
+
+### RE-EXPORT.
+ep = export(tagged_module, test_inputs, dynamic_shapes=None, strict=True)
+exported_result = ep.module()(*test_inputs)
+print("exported program (after tagging): ", exported_result)
+print(torch.allclose(quantized_result, exported_result, atol=1e-3))
+# Check tagged nodes:
+for node in list(ep.graph.nodes):
+    if 'custom' in node.meta:
+        print(f"Node: {node.name}, meta: {node.meta['custom']}")
+
+### TO_EDGE_TRANSFORM_AND_LOWER.
+DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+    per_op_mode=True,
+)
+edge = to_edge_transform_and_lower(
+    ep,
+    compile_config=EdgeCompileConfig(_check_ir_validity=False),
+    partitioner=[XnnpackPartitioner()],
+    generate_etrecord=False,
+)
+# ^ after this, the graph has a single node? torchao_dequantize_affine_default
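+# Optionally, mirror the tagged-node check above to see which nodes (if any)
+# still carry the 'custom' tag after lowering:
+for node in edge.exported_program().graph.nodes:
+    if "custom" in node.meta:
+        print(f"Node: {node.name}, meta: {node.meta['custom']}")
+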
+edge_result = edge.exported_program().module()(*test_inputs)
+print("edge program: ", edge_result)
+print(torch.allclose(quantized_result, edge_result, atol=1e-3))
+edge.exported_program().graph_module.print_readable()
+
+### TO_EXECUTORCH.
+exec = edge.to_executorch(ExecutorchBackendConfig())
+
+### SAVE ET MODEL TO DISK.
+with open("model.pte", "wb") as f:
+    f.write(exec.buffer)
+
+if len(exec._tensor_data) == 0:
+    print("No external data saved")
+else:
+    exec.write_tensor_data_to_file(".")
+
+### LOAD AND RUN VIA PYBINDINGS.
+import torch
+from executorch.extension.pybindings import portable_lib
+module = portable_lib._load_for_executorch("model.pte", "model.ptd")
+exec_result = module.forward(test_inputs)
+# Expecting key: a4f6ff98c9db8ecfe5c11e87d07f182a58cb9696f01086d9e0cdc2e986fab003
+# Scale: a4f6ff98c9db8ecfe5c11e87d07f182a58cb9696f01086d9e0cdc2e986fab003
+print("executorch program: ", exec_result)
+print(torch.allclose(quantized_result, exec_result[0], atol=1e-3))
+
+print("End.")
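+
+# Optional sanity check (a minimal sketch using only the standard library):
+# confirm that the program file and, when constants were tagged as external,
+# the weight file were written to the current directory.
+import os
+for fname in ("model.pte", "model.ptd"):
+    if os.path.exists(fname):
+        print(fname, "size:", os.path.getsize(fname), "bytes")
+    else:
+        print(fname, "not found")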