5 changes: 5 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -19,6 +19,7 @@
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
ConvertToUpsampleBilinear2d,
)
from executorch.backends.xnnpack._passes.external_constants_pass import PropagateCustomMetaPass
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
@@ -59,6 +60,9 @@ def __init__(
DimOrderOpsRevertPass,
ConvertToUpsampleBilinear2d,
ConvertToLinearPass,
# Propagate node.meta['custom'] to q/dq nodes: if a q/dq node's input is tagged, tag the q/dq node as well.
PropagateCustomMetaPass,
ConvertToSDPAPass,
ConstPropPass,
FuseBatchNormPass,
@@ -92,4 +96,5 @@ def transform(self) -> ExportedProgram:
f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}"
)
ep = _transform(ep, transform_pass)
breakpoint()
return ep
48 changes: 48 additions & 0 deletions backends/xnnpack/_passes/external_constants_pass.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.pass_base import PassResult


class PropagateCustomMetaPass(XNNPACKPass):
    """
    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.

    For all quantize/dequantize nodes in the graph, if the parent (input) node has a
    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph

        for node in graph.nodes:
            # Check if this is a quantize or dequantize node
            if not (is_quant(node) or is_dequant(node)):
                continue

            # Get the parent node (first input argument)
            if len(node.all_input_nodes) == 0:
                continue

            parent_node = node.args[0]
            if not isinstance(parent_node, torch.fx.Node):
                continue

            # Check if parent has 'custom' metadata
            if "custom" in parent_node.meta:
                breakpoint()
                # Copy the custom metadata to the q/dq node
                node.meta["custom"] = parent_node.meta["custom"]

        graph_module.recompile()

        # Call parent's call method to retrace and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
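
A minimal sketch, not part of this PR, of how the new pass could be exercised in isolation on an already-tagged ExportedProgram. It assumes XNNPACKPass takes the ExportedProgram in its constructor; the helper name check_qdq_tagging is hypothetical.

from executorch.backends.xnnpack._passes.external_constants_pass import PropagateCustomMetaPass

def check_qdq_tagging(ep):
    # Hypothetical helper: run the pass directly on the program's graph module
    # and print every node that carries a 'custom' tag afterwards.
    result = PropagateCustomMetaPass(ep).call(ep.graph_module)
    for node in result.graph_module.graph.nodes:
        if "custom" in node.meta:
            print(f"{node.name}: {node.meta['custom']}")
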
2 changes: 2 additions & 0 deletions backends/xnnpack/operators/node_visitor.py
@@ -296,6 +296,7 @@ def get_quant_params(
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
)
)
breakpoint()
self._named_data_store.add_named_data(
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
)
@@ -630,6 +631,7 @@ def get_serialized_buffer_index(
logging.info(
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
)
breakpoint()
self._named_data_store.add_named_data(
named_key,
bytes(array),
1 change: 1 addition & 0 deletions backends/xnnpack/operators/op_linear.py
@@ -51,6 +51,7 @@ def define_node(

# filter
weight_node = get_input_node(node, 1)
breakpoint()
weight_quant_params = QuantParams.from_weights(
weight_node, self._exported_program
)
1 change: 1 addition & 0 deletions exir/program/_program.py
@@ -1372,6 +1372,7 @@ def to_edge_transform_and_lower(  # noqa: C901
for name, partitioner_list in partitioner.items():
if i < len(partitioner_list):
method_to_partitioner[name] = partitioner_list[i]
breakpoint()
edge_manager = edge_manager.to_backend(method_to_partitioner)

for name, program in edge_manager._edge_programs.items():
165 changes: 165 additions & 0 deletions export_quant_xnnpack.py
@@ -0,0 +1,165 @@

import torch

from torchao.quantization.granularity import PerGroup, PerAxis
from torchao.quantization.quant_api import (
IntxWeightOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
quantize_,
)
from torchao.utils import unwrap_tensor_subclass
from torch.export import export, ExportedProgram
from executorch.exir import (
EdgeProgramManager,
ExecutorchBackendConfig,
ExecutorchProgramManager,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackFloatingPointPartitioner,
XnnpackPartitioner,
)
from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
to_edge_transform_and_lower,
)

# Note: I think this works fine.
# Quantize embeddings with 8-bits, per channel
# embedding_config = IntxWeightOnlyConfig(
#     weight_dtype=torch.int8,
#     granularity=PerAxis(0),
# )
# quantize_(
#     eager_model,
#     embedding_config,
#     lambda m, fqn: isinstance(m, torch.nn.Embedding),
# )

torch.manual_seed(0)

class ModuleLinear(torch.nn.Module):
    def __init__(
        self,
        in_size: int = 2,
        input_channels: int = 4,
        output_channels: int = 4,
        dtype: torch.dtype = torch.float,
        use_bias: bool = False
    ):
        super().__init__()
        self.linear = torch.nn.Linear(
            input_channels, output_channels, bias=use_bias
        ).to(dtype=dtype)

        self.ic = input_channels
        self.oc = output_channels
        assert dtype in [torch.float, torch.half], "Unsupported op dtype"
        self.op_dtype = dtype
        self.in_size = in_size

    def forward(self, x: torch.Tensor):
        return self.linear(x)

    def get_random_inputs(self):
        inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
        return (inp,)

### EAGER.
eager_model = ModuleLinear(
in_size=1,
input_channels=32,
output_channels=2,
)

test_inputs = eager_model.get_random_inputs()
eager_result = eager_model(*test_inputs)
print("eager result: ", eager_result)

### QUANTIZE.
# Quantize linear layers with 8-bit dynamic activations and 4-bit weights.
linear_config = Int8DynamicActivationIntxWeightConfig(
    # lintrunner-mypy flags this line: 'Module has no attribute "int4"'.
    weight_dtype=torch.int4,  # type: ignore[attr-defined]
    weight_granularity=PerGroup(32),
)
# NOTE: if the quantize_ call below is commented out, program-data separation works as expected.
quantize_(eager_model, linear_config)

quantized_result = eager_model(*test_inputs)
print("quantized results: ", quantized_result)
print(torch.allclose(eager_result, quantized_result, atol=1e-1))

unwrap_tensor_subclass(eager_model)
unwrapped_result = eager_model(*test_inputs)
print("unwrapped results: ", unwrapped_result)
print(torch.allclose(quantized_result, unwrapped_result, atol=1e-3))

from executorch.exir.passes.external_constants_pass import (
delegate_external_constants_pass_unlifted,
)
### EXPORT AND TAG WEIGHTS.
ep1 = export(eager_model, test_inputs, dynamic_shapes=None, strict=True)
exported_result = ep1.module()(*test_inputs)
print("exported program: ", exported_result)
print(torch.allclose(quantized_result, exported_result, atol=1e-3))
print("Graph: ")
ep1.graph_module.print_readable()
# Tag the unlifted ep.module().
tagged_module = ep1.module()
delegate_external_constants_pass_unlifted(
module=tagged_module,
gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
)

### RE-EXPORT.
ep = export(tagged_module, test_inputs, dynamic_shapes=None, strict=True)
exported_result = ep.module()(*test_inputs)
print("exported program (after tagging): ", exported_result)
print(torch.allclose(quantized_result, exported_result, atol=1e-3))
# Check tagged nodes:
for node in list(ep.graph.nodes):
    if 'custom' in node.meta:
        print(f"Node: {node.name}, meta: {node.meta['custom']}")

### TO_EDGE_TRANSFORM_AND_LOWER.
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=True,
)
edge = to_edge_transform_and_lower(
ep,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
partitioner=[XnnpackPartitioner()],
generate_etrecord=False,
)
# ^ after this, the graph has a single node? torchao_dequantize_affine_default
edge_result = edge.exported_program().module()(*test_inputs)
print("edge program: ", edge_result)
print(torch.allclose(quantized_result, edge_result, atol=1e-3))
edge.exported_program().graph_module.print_readable()
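
# Hypothetical inspection (not in the original script): count the delegate calls
# in the lowered program to see whether the linear actually reached the XNNPACK
# backend, which bears on the single-node observation above.
delegate_calls = [
    n
    for n in edge.exported_program().graph.nodes
    if n.op == "call_function" and "executorch_call_delegate" in str(n.target)
]
print("delegate calls after lowering:", len(delegate_calls))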

### TO_EXECUTORCH.
exec = edge.to_executorch(ExecutorchBackendConfig())

### SAVE ET MODEL TO DISK.
with open("model.pte", "wb") as f:
f.write(exec.buffer)

if len(exec._tensor_data) == 0:
    print("No external data saved")
else:
    exec.write_tensor_data_to_file(".")
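
# Hypothetical sanity check (not in the original script): with gen_tag_fn returning
# "model", program-data separation should have written the weights to model.ptd.
import os
print("model.ptd exists:", os.path.exists("model.ptd"))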

### LOAD AND RUN VIA PYBINDINGS.
import torch
from executorch.extension.pybindings import portable_lib
module = portable_lib._load_for_executorch("model.pte", "model.ptd")
exec_result = module.forward(test_inputs)
# Expecting key: a4f6ff98c9db8ecfe5c11e87d07f182a58cb9696f01086d9e0cdc2e986fab003
# Scale: a4f6ff98c9db8ecfe5c11e87d07f182a58cb9696f01086d9e0cdc2e986fab003
print("executorch program: ", exec_result)
print(torch.allclose(quantized_result, exec_result[0], atol=1e-3))

print("End.")