3 changes: 2 additions & 1 deletion backends/test/harness/stages/__init__.py
@@ -1,6 +1,6 @@
from .export import Export
from .partition import Partition
from .quantize import Quantize
from .quantize import Quantize, Quantize_
from .run_passes import RunPasses
from .serialize import Serialize
from .stage import Stage, StageType
@@ -12,6 +12,7 @@
"Export",
"Partition",
"Quantize",
"Quantize_",
"RunPasses",
"Serialize",
"Stage",
49 changes: 48 additions & 1 deletion backends/test/harness/stages/quantize.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Callable, Optional, Sequence, Tuple

import torch

@@ -15,6 +15,8 @@
prepare_qat_pt2e,
)
from torchao.quantization.pt2e.quantizer import Quantizer
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass


class Quantize(Stage):
@@ -79,3 +81,48 @@ def graph_module(self) -> str:

def run_artifact(self, inputs):
return self.converted_graph.forward(*inputs)


Contributor Author: @GregoryComer let me know if I should move this into the test class, instead of the test harness, or rename.

Contributor: FWIW, I like it here.

class Quantize_(Stage):
"""
TorchAO quantization stage using the quantize_ API.
"""

def __init__(
self,
config: Any,
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
):
"""
Args:
config: TorchAO quantization config (e.g., Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig)
filter_fn: Optional filter function to select which modules to quantize
"""
self.config = config
self.filter_fn = filter_fn
self.quantized_module = None

def stage_type(self) -> str:
return StageType.QUANTIZE

def run(
self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
) -> None:
# Apply quantize_ to the model
quantize_(artifact, self.config, self.filter_fn)

# Unwrap tensor subclasses for export compatibility
unwrap_tensor_subclass(artifact)

self.quantized_module = artifact

@property
def artifact(self) -> torch.nn.Module:
return self.quantized_module

@property
def graph_module(self) -> torch.nn.Module:
return self.quantized_module

def run_artifact(self, inputs):
return self.quantized_module.forward(*inputs)
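
For orientation, here is a minimal usage sketch of the new stage (illustrative only, not part of the diff): the config values mirror the test added further down, and the bare Linear module is a stand-in for a real model.

# Sketch: drive the new Quantize_ stage directly with a torchao config.
import torch
from executorch.backends.test.harness.stages import Quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
stage = Quantize_(config=config)

# quantize_ mutates the module in place; the stage then unwraps tensor
# subclasses so the result remains exportable with torch.export.
stage.run(torch.nn.Linear(32, 2, bias=False), inputs=None)
quantized_module = stage.artifact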
4 changes: 4 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -23,6 +23,9 @@
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
PropagateCustomMetaPass,
)
from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
RemoveRedundantCopyPass,
)
@@ -59,6 +62,7 @@ def __init__(
DimOrderOpsRevertPass,
ConvertToUpsampleBilinear2d,
ConvertToLinearPass,
PropagateCustomMetaPass,
ConvertToSDPAPass,
ConstPropPass,
FuseBatchNormPass,
44 changes: 44 additions & 0 deletions backends/xnnpack/_passes/propagate_custom_meta_pass.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.pass_base import PassResult


class PropagateCustomMetaPass(XNNPACKPass):
"""
Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
For all quantize/dequantize nodes in the graph, if the parent node has a
node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
"""

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph

for node in graph.nodes:
if not (is_quant(node) or is_dequant(node)):
continue

# Get the parent node (first input argument)
if len(node.all_input_nodes) == 0:
continue

parent_node = node.args[0]
if not isinstance(parent_node, torch.fx.Node):
continue

if "custom" in parent_node.meta:
node.meta["custom"] = parent_node.meta["custom"]

graph_module.recompile()

# Since we are overriding "call", we need to call the parent's "call"
# to retrace the graph and regenerate metadata
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
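
To illustrate the behavior (not part of the diff): the pass copies node.meta["custom"] from a node's first input onto quantize/dequantize nodes, which appears to be what lets tags added by delegate_external_constants_pass_unlifted in the test below carry over to the q/dq nodes around tagged weights. A standalone FX sketch of the copy step, with a hypothetical tag value:

# Standalone sketch (hypothetical tag value; the real pass only touches
# quantize/dequantize nodes, identified via is_quant/is_dequant).
import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

gm = symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["custom"] = {"tag": "model"}  # hypothetical tag

# The propagation step: a child inherits its parent's "custom" entry.
for node in gm.graph.nodes:
    parent = node.args[0] if node.args and isinstance(node.args[0], torch.fx.Node) else None
    if parent is not None and "custom" in parent.meta:
        node.meta["custom"] = dict(parent.meta["custom"])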
161 changes: 161 additions & 0 deletions backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple, Union

import executorch.backends.test.harness.stages as BaseStages

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass_unlifted,
)

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

try:
import executorch.extension.pybindings.portable_lib # noqa[F401]
import executorch.kernels.quantized # noqa[F401]

has_quantized_ops = True
except:
has_quantized_ops = False
print("Missing quantized ops")


class TestPropagateCustomMetaPass(unittest.TestCase):
class ModuleLinear(torch.nn.Module):
def __init__(
self,
in_size: int = 2,
input_channels: int = 4,
output_channels: int = 4,
dtype: torch.dtype = torch.float,
use_bias: bool = False,
):
super().__init__()
self.linear = torch.nn.Linear(
input_channels, output_channels, bias=use_bias
).to(dtype=dtype)

self.ic = input_channels
self.oc = output_channels
assert dtype in [torch.float, torch.half], "Unsupported op dtype"
self.op_dtype = dtype
self.in_size = in_size

def forward(self, x: torch.Tensor):
return self.linear(x)

def get_random_inputs(self):
inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
return (inp,)

class Export(BaseStages.Export):
def run(
self,
artifact: torch.nn.Module,
inputs: Tuple[torch.Tensor],
) -> None:

tagged_module = torch.export.export(
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
).module()
delegate_external_constants_pass_unlifted(
module=tagged_module,
gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
)
self.exported_program = torch.export.export(
tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
)

def _test_linear(
self,
partitioner: XnnpackPartitioner,
quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
):
eager_model = self.ModuleLinear(
in_size=1,
input_channels=32,
output_channels=2,
)
test_inputs = eager_model.get_random_inputs()

tester = Tester(eager_model, test_inputs)
tester.quantize(quantization_stage)
tester.export(self.Export())
tester.to_edge_transform_and_lower(
ToEdgeTransformAndLower([partitioner])
).to_executorch()
tester.run_method_and_compare_outputs()

exec = tester.get_artifact()
program_buffer = exec.buffer
self.assertEqual(len(exec._tensor_data), 1)
data_buffer = bytes(exec._tensor_data.pop("model"))
Contributor: Do we want to assert the size of it? Just to make sure it is indeed the quantized weight tensor. Also I would like (somehow) to validate that we also didn't put this in the blob, perhaps by asserting that the blob size is < weight_size (if we have large-ish weights).

Contributor Author (lucylq, Oct 9, 2025): Thanks - I'll add a check on size. Re validating that it's not in the blob, we can check that forward fails when we do not pass in the data buffer.

Contributor Author: @digantdesai added a check on size, and a check on accuracy. Verified locally that we're missing the weight if we do not pass in the data buffer. The test segfaults after ~4 runs though; maybe something isn't cleaned up properly in pybindings. Left that as a todo for now.

self.assertTrue(len(data_buffer) > 200)
from executorch.extension.pybindings import portable_lib as runtime

module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
output = module.forward(test_inputs)
reference_output = exec.exported_program().module()(
test_inputs[0],
)
self.assertTrue(torch.allclose(output[0], reference_output, 1e-2))

# with self.assertRaises(RuntimeError):
# runtime._load_for_executorch_from_buffer(program_buffer).forward(
# test_inputs
# )

def test_quantize_(self):
# Quantize with torchao quantize_ API.
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=False,
)
linear_config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=PerGroup(32),
)
self._test_linear(
DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
)

def test_pt2e_quantize(self):
# Quantize with pt2e quantize.
quant_configs = [
# per_tensor
get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
# per_channel
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
# per_channel_dynamic
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
]
for quant_config in quant_configs:
precision = (
ConfigPrecisionType.DYNAMIC_QUANT
if quant_config.input_activation.is_dynamic
else ConfigPrecisionType.STATIC_QUANT
)
for per_op_mode in [True, False]:
partitioner = XnnpackPartitioner(
config_precisions=precision, per_op_mode=per_op_mode
)
self._test_linear(
partitioner, XNNPackQuantize(quantization_config=quant_config)
)