Add xnnpack pass to propagate custom meta field to q/dq nodes #14864
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.pass_base import PassResult


class PropagateCustomMetaPass(XNNPACKPass):
    """
    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
    For all quantize/dequantize nodes in the graph, if the parent node has a
    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph

        for node in graph.nodes:
            if not (is_quant(node) or is_dequant(node)):
                continue

            # Get the parent node (first input argument)
            if len(node.all_input_nodes) == 0:
                continue

            parent_node = node.args[0]
            if not isinstance(parent_node, torch.fx.Node):
                continue

            if "custom" in parent_node.meta:
                node.meta["custom"] = parent_node.meta["custom"]

        graph_module.recompile()

        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
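For intuition, here is a toy sketch (not part of the PR) of the propagation rule the pass applies, written against a plain torch.fx graph. A stand-in predicate replaces is_quant/is_dequant, and the contents of meta["custom"] are made up for illustration.

# Toy illustration only: mimics the propagation rule on a plain torch.fx graph.
# The predicate and the payload stored in meta["custom"] are hypothetical.
import torch
import torch.fx


class TinyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1.0)


gm = torch.fx.symbolic_trace(TinyModule())


# Stand-in for is_quant/is_dequant: treat the relu node as the "q/dq" node.
def is_fake_qdq(node: torch.fx.Node) -> bool:
    return node.op == "call_function" and node.target is torch.relu


# Tag the other call_function node (the add) with a custom meta entry, the way
# an earlier pass might tag nodes that own externally stored constants.
for node in gm.graph.nodes:
    if node.op == "call_function" and not is_fake_qdq(node):
        node.meta["custom"] = {"tag": "model"}  # hypothetical payload

# Same rule as PropagateCustomMetaPass: copy the parent's meta["custom"]
# onto each "q/dq" node when the parent has one.
for node in gm.graph.nodes:
    if not is_fake_qdq(node) or len(node.all_input_nodes) == 0:
        continue
    parent = node.args[0]
    if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
        node.meta["custom"] = parent.meta["custom"]

relu_node = next(n for n in gm.graph.nodes if is_fake_qdq(n))
print(relu_node.meta.get("custom"))  # {'tag': 'model'}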
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple, Union

import executorch.backends.test.harness.stages as BaseStages

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

try:
    import executorch.extension.pybindings.portable_lib  # noqa[F401]
    import executorch.kernels.quantized  # noqa[F401]

    has_quantized_ops = True
except:
    has_quantized_ops = False
    print("Missing quantized ops")


class TestPropagateCustomMetaPass(unittest.TestCase):
    class ModuleLinear(torch.nn.Module):
        def __init__(
            self,
            in_size: int = 2,
            input_channels: int = 4,
            output_channels: int = 4,
            dtype: torch.dtype = torch.float,
            use_bias: bool = False,
        ):
            super().__init__()
            self.linear = torch.nn.Linear(
                input_channels, output_channels, bias=use_bias
            ).to(dtype=dtype)

            self.ic = input_channels
            self.oc = output_channels
            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
            self.op_dtype = dtype
            self.in_size = in_size

        def forward(self, x: torch.Tensor):
            return self.linear(x)

        def get_random_inputs(self):
            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
            return (inp,)

    class Export(BaseStages.Export):
        def run(
            self,
            artifact: torch.nn.Module,
            inputs: Tuple[torch.Tensor],
        ) -> None:
            tagged_module = torch.export.export(
                artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
            ).module()
            delegate_external_constants_pass_unlifted(
                module=tagged_module,
                # gen_tag_fn gives the filename the weights will be saved to.
                # In this case, weights will be saved as "model.ptd".
                gen_tag_fn=lambda x: "model",
            )
            self.exported_program = torch.export.export(
                tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
            )

    def _test_linear(
        self,
        partitioner: XnnpackPartitioner,
        quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
    ):
        eager_model = self.ModuleLinear(
            in_size=1,
            input_channels=32,
            output_channels=2,
        )
        test_inputs = eager_model.get_random_inputs()

        tester = Tester(eager_model, test_inputs)
        tester.quantize(quantization_stage)
        tester.export(self.Export())
        tester.to_edge_transform_and_lower(
            ToEdgeTransformAndLower([partitioner])
        ).to_executorch()
        tester.run_method_and_compare_outputs()

        exec = tester.get_artifact()
        program_buffer = exec.buffer
        self.assertEqual(len(exec._tensor_data), 1)
        data_buffer = bytes(exec._tensor_data.pop("model"))
Review comment (on the line above): Do we want to assert the size of it? Just to make sure it is indeed the quantized weight tensor. Also I would like (somehow) to validate that we also didn't put this in the blob, perhaps by asserting that the blob size is < weight_size (if we have large-ish weights).

Reply: Thanks - I'll add a check on size. Re validating that it's not in the blob, we can check that forward fails when we do not pass in the data buffer.

Reply: @digantdesai added a check on size and a check on accuracy. Verified locally that we're missing the weight if we do not pass in the data buffer. The test segfaults after ~4 runs though; maybe something isn't cleaned up properly in pybindings. Left that as a TODO for now.
        self.assertTrue(len(data_buffer) > 200)
        from executorch.extension.pybindings import portable_lib as runtime

        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
        output = module.forward(test_inputs)
        reference_output = exec.exported_program().module()(
            test_inputs[0],
        )
        self.assertTrue(torch.allclose(output[0], reference_output, 1e-2))

        # with self.assertRaises(RuntimeError):
        #     runtime._load_for_executorch_from_buffer(program_buffer).forward(
        #         test_inputs
        #     )

    def test_quantize_(self):
        # Quantize with torchao quantize_ API.
        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
            per_op_mode=False,
        )
        linear_config = Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
        )
        self._test_linear(
            DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
        )

    def test_pt2e_quantize(self):
        # Quantize with pt2e quantize.
        quant_configs = [
            # per_tensor
            get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
            # per_channel
            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
            # per_channel_dynamic
            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
        ]
        for quant_config in quant_configs:
            precision = (
                ConfigPrecisionType.DYNAMIC_QUANT
                if quant_config.input_activation.is_dynamic
                else ConfigPrecisionType.STATIC_QUANT
            )
            for per_op_mode in [True, False]:
                partitioner = XnnpackPartitioner(
                    config_precisions=precision, per_op_mode=per_op_mode
                )
                self._test_linear(
                    partitioner, XNNPackQuantize(quantization_config=quant_config)
                )
Review comment: @GregoryComer let me know if I should move this into the test class, instead of the test harness, or rename.

Reply: FWIW, I like it here.