5 changes: 3 additions & 2 deletions .ci/scripts/test_llama_lora.sh
@@ -107,8 +107,9 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
model.dtype_override="fp32" \
backend.xnnpack.enabled=true \
backend.xnnpack.extended_ops=true \
export.output_name="${MODEL_SEPARATE}.pte" \
export.foundation_weights_file="${MODEL_SEPARATE}.ptd"
quantization.pt2e_quantize="xnnpack_dynamic" \
export.output_name="${MODEL}.pte" \
export.foundation_weights_file="${MODEL}.ptd"

# Run llama runner.
NOW=$(date +"%H:%M:%S")
3 changes: 2 additions & 1 deletion backends/test/harness/stages/__init__.py
@@ -1,6 +1,6 @@
from .export import Export
from .partition import Partition
from .quantize import Quantize
from .quantize import Quantize, Quantize_
from .run_passes import RunPasses
from .serialize import Serialize
from .stage import Stage, StageType
@@ -12,6 +12,7 @@
"Export",
"Partition",
"Quantize",
"Quantize_",
"RunPasses",
"Serialize",
"Stage",
57 changes: 56 additions & 1 deletion backends/test/harness/stages/quantize.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Callable, Optional, Sequence, Tuple

import torch

Expand All @@ -15,6 +15,16 @@
prepare_qat_pt2e,
)
from torchao.quantization.pt2e.quantizer import Quantizer
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass


class Quantize(Stage):
@@ -79,3 +89,48 @@ def graph_module(self) -> str:

def run_artifact(self, inputs):
return self.converted_graph.forward(*inputs)


class Quantize_(Stage):
"""
TorchAO quantization stage using the quantize_ API.
"""

def __init__(
self,
config: Any,
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
):
"""
Args:
config: TorchAO quantization config (e.g., Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig)
filter_fn: Optional filter function to select which modules to quantize
"""
self.config = config
self.filter_fn = filter_fn
self.quantized_module = None

def stage_type(self) -> StageType:
return StageType.QUANTIZE

def run(
self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
) -> None:
# Apply quantize_ to the model
quantize_(artifact, self.config, self.filter_fn)

# Unwrap tensor subclasses for export compatibility
unwrap_tensor_subclass(artifact)

self.quantized_module = artifact

@property
def artifact(self) -> torch.nn.Module:
return self.quantized_module

@property
def graph_module(self) -> torch.nn.Module:
return self.quantized_module

def run_artifact(self, inputs):
return self.quantized_module.forward(*inputs)
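
For reference, a minimal usage sketch of the new Quantize_ stage (the import path follows the harness package above; the model and config values are illustrative and mirror the XNNPACK test later in this diff):

import torch
from executorch.backends.test.harness.stages import Quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

# Illustrative eager model; any torch.nn.Module works.
model = torch.nn.Sequential(torch.nn.Linear(32, 2))

stage = Quantize_(
    config=Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    )
)
stage.run(model, (torch.randn(1, 32),))  # applies torchao quantize_ in place
quantized = stage.artifact               # quantized module, ready for torch.export

Unlike the pt2e Quantize stage, no calibration inputs are needed: the config is applied directly to the eager module.
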
1 change: 1 addition & 0 deletions backends/test/harness/tester.py
@@ -9,6 +9,7 @@
Export,
Partition,
Quantize,
Quantize_,
RunPasses,
Serialize,
Stage,
4 changes: 4 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -23,6 +23,9 @@
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
PropagateCustomMetaPass,
)
from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
RemoveRedundantCopyPass,
)
@@ -59,6 +62,7 @@ def __init__(
DimOrderOpsRevertPass,
ConvertToUpsampleBilinear2d,
ConvertToLinearPass,
PropagateCustomMetaPass,
ConvertToSDPAPass,
ConstPropPass,
FuseBatchNormPass,
45 changes: 45 additions & 0 deletions backends/xnnpack/_passes/propagate_custom_meta_pass.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.pass_base import PassResult


class PropagateCustomMetaPass(XNNPACKPass):
"""
Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
For all quantize/dequantize nodes in the graph, if the parent node has a
node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
"""

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph

for node in graph.nodes:
if not (is_quant(node) or is_dequant(node)):
continue

# Get the parent node (first input argument)
if len(node.all_input_nodes) == 0:
continue

parent_node = node.args[0]
if not isinstance(parent_node, torch.fx.Node):
continue

if "custom" in parent_node.meta:
print(f"PROPAGATING CUSTOM META FROM {parent_node.name} TO {node.name}")
node.meta["custom"] = parent_node.meta["custom"]

graph_module.recompile()

# Since we are overriding "call", we need to call the parent's "call"
# to retrace the graph and regenerate metadata
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
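
The propagation step itself is small; the following standalone sketch shows the same idea on a plain torch.fx graph (torch.relu stands in for a quantize/dequantize node and the "custom" payload is an arbitrary example value, so this is an illustration rather than the XNNPACKPass wiring above):

import torch
from torch.fx import symbolic_trace


def propagate_custom_meta(gm: torch.fx.GraphModule, is_target) -> None:
    # Mirrors the pass body: copy meta["custom"] from a matching node's
    # first input node onto the node itself.
    for node in gm.graph.nodes:
        if not is_target(node) or not node.all_input_nodes:
            continue
        parent = node.args[0]
        if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
            node.meta["custom"] = parent.meta["custom"]


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


gm = symbolic_trace(M())
placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
placeholder.meta["custom"] = {"tag": "model"}  # arbitrary payload for illustration

propagate_custom_meta(gm, lambda n: n.target is torch.relu)
relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)
assert relu_node.meta["custom"] == {"tag": "model"}
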
2 changes: 2 additions & 0 deletions backends/xnnpack/operators/node_visitor.py
@@ -296,6 +296,7 @@ def get_quant_params(
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
)
)
print(f"NDM: adding scale tensor with key {scale_name}")
self._named_data_store.add_named_data(
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
)
@@ -630,6 +631,7 @@ def get_serialized_buffer_index(
logging.info(
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
)
print(f"NDM: Adding constant data with name {tensor.name}, key {named_key} and tag {external_tag}")
self._named_data_store.add_named_data(
named_key,
bytes(array),
158 changes: 158 additions & 0 deletions backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple, Union

import torch

# BaseStages is referenced throughout this test; import the shared harness
# stages package (updated elsewhere in this diff).
from executorch.backends.test.harness import stages as BaseStages
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass_unlifted,
)

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

try:
    import executorch.extension.pybindings.portable_lib  # noqa: F401
    import executorch.kernels.quantized  # noqa: F401

    has_quantized_ops = True
except ImportError:
    has_quantized_ops = False
    print("Missing quantized ops")


class TestPropagateCustomMetaPass(unittest.TestCase):
class ModuleLinear(torch.nn.Module):
def __init__(
self,
in_size: int = 2,
input_channels: int = 4,
output_channels: int = 4,
dtype: torch.dtype = torch.float,
use_bias: bool = False,
):
super().__init__()
self.linear = torch.nn.Linear(
input_channels, output_channels, bias=use_bias
).to(dtype=dtype)

self.ic = input_channels
self.oc = output_channels
assert dtype in [torch.float, torch.half], "Unsupported op dtype"
self.op_dtype = dtype
self.in_size = in_size

def forward(self, x: torch.Tensor):
return self.linear(x)

def get_random_inputs(self):
inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
return (inp,)

class Export(BaseStages.Export):
def run(
self,
artifact: torch.nn.Module,
inputs: Tuple[torch.Tensor],
) -> None:

tagged_module = torch.export.export(
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
).module()
delegate_external_constants_pass_unlifted(
module=tagged_module,
gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
)
self.exported_program = torch.export.export(
tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
)

def _test_linear(
self,
partitioner: XnnpackPartitioner,
quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_]
):
eager_model = self.ModuleLinear(
in_size=1,
input_channels=32,
output_channels=2,
)
test_inputs = eager_model.get_random_inputs()

tester = Tester(eager_model, test_inputs)
tester.quantize(quantization_stage)
tester.export(self.Export())
tester.to_edge_transform_and_lower(
ToEdgeTransformAndLower([partitioner])
).to_executorch()
tester.run_method_and_compare_outputs()

exec = tester.get_artifact()
program_buffer = exec.buffer
self.assertEqual(len(exec._tensor_data), 1)
data_buffer = bytes(exec._tensor_data.pop("model"))
self.assertTrue(len(data_buffer) > 200)
from executorch.extension.pybindings import portable_lib as runtime

module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
output = module.forward(test_inputs)
reference_output = exec.exported_program().module()(
test_inputs[0],
)
self.assertTrue(torch.allclose(output[0], reference_output))

# TODO(lfq): Loading without the data buffer raises as expected, but segfaults after a few runs.
# with self.assertRaises(RuntimeError):
# runtime._load_for_executorch_from_buffer(program_buffer).forward(
# test_inputs
# )

def test_quantize_(self):
# Quantize with torchao quantize_ API.
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=False,
)
linear_config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=PerGroup(32),
)
self._test_linear(
DynamicallyQuantizedPartitioner,
BaseStages.Quantize_(config=linear_config))

def test_pt2e_quantize(self):
# Quantize with pt2e quantize.
quant_configs = [
# per_channel
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
# per_tensor
get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
# per_channel_dynamic
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
]
partitioners = []
for config_precision in [ConfigPrecisionType.STATIC_QUANT, ConfigPrecisionType.DYNAMIC_QUANT]:
for per_op_mode in [True, False]:
partitioners.append(
XnnpackPartitioner(
config_precisions=config_precision,
per_op_mode=per_op_mode,
)
)
for quant_config in quant_configs:
for partitioner in partitioners:
self._test_linear(partitioner, XNNPackQuantize(quantization_config=quant_config))