17 changes: 16 additions & 1 deletion exir/passes/constant_prop_pass.py
@@ -170,7 +170,22 @@ def replace_with_constant_node(
     exported_program: ExportedProgram,
 ) -> tuple[torch.fx.Node, str]:
     # Add `prop_constant_tensor` to program.state_dict.
-    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
+    prefix = "_prop_tensor_constant"
+    prop_constant_tensor_fqn = f"{prefix}{len(exported_program.constants)}"
+    # If prop_constant_tensor_fqn already exists in exported_program.constants,
+    # create a new name: find the largest numeric suffix of
+    # "_prop_tensor_constant" and increment it by 1.
+    if prop_constant_tensor_fqn in exported_program.constants:
+        suffix = 1 + max(
+            (
+                int(name[len(prefix) :])
+                for name in exported_program.constants.keys()
+                if name.startswith(prefix) and name[len(prefix) :].isdigit()
+            ),
+            default=-1,
+        )
+        prop_constant_tensor_fqn = f"{prefix}{suffix}"
+
     exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
 
     # Insert a new placeholder node for the propagated constant tensor.
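
For reference, a minimal standalone sketch (not part of this PR) of the suffix-increment naming scheme used above; the helper `next_constant_name` and the sample `existing` dict are hypothetical illustrations:

# Minimal sketch of the collision-avoiding naming scheme shown in the diff.
# `next_constant_name` is a hypothetical helper, not part of the PR.
def next_constant_name(existing: dict, prefix: str = "_prop_tensor_constant") -> str:
    candidate = f"{prefix}{len(existing)}"
    if candidate not in existing:
        return candidate
    # Collision: take the largest existing numeric suffix and add 1.
    suffix = 1 + max(
        (
            int(name[len(prefix) :])
            for name in existing
            if name.startswith(prefix) and name[len(prefix) :].isdigit()
        ),
        default=-1,
    )
    return f"{prefix}{suffix}"

# Example: two constants exist but their indices are 2 and 3, so the
# len()-based candidate "_prop_tensor_constant2" collides and the helper
# falls back to "_prop_tensor_constant4".
existing = {"_prop_tensor_constant2": None, "_prop_tensor_constant3": None}
assert next_constant_name(existing) == "_prop_tensor_constant4"
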
1 change: 1 addition & 0 deletions exir/tests/TARGETS
@@ -221,6 +221,7 @@ python_unittest(
"//executorch/exir/passes:sym_to_tensor_pass",
"//executorch/exir/program:program",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
],
)

77 changes: 77 additions & 0 deletions exir/tests/test_passes.py
@@ -16,6 +16,7 @@
 # Import passes
 import executorch.exir.memory_planning  # noqa
 import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -65,10 +66,12 @@
 from torch import nn
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -1238,6 +1241,80 @@ def forward(self, x):
             ],
         )
 
+    def test_constant_prop_pass_after_delegation(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self, dim=32):
+                super().__init__()
+                self.linear = torch.nn.Linear(dim, dim)
+
+            def forward(self, query, key, value):
+                query = self.linear(query)
+                key = self.linear(key)
+                value = self.linear(value)
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
+                )
+
+        query = torch.randn(32, 32, 32, 32)
+        key = torch.randn(32, 32, 32, 32)
+        value = torch.randn(32, 32, 32, 32)
+
+        # Capture the model
+        m = torch.export.export_for_training(M(32), (query, key, value)).module()
+
+        # 8w16a quantization
+        from torch.ao.quantization.observer import (
+            MinMaxObserver,
+            PerChannelMinMaxObserver,
+        )
+
+        activation_qspec = QuantizationSpec(
+            dtype=torch.int16,
+            quant_min=-32768,
+            quant_max=32767,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=MinMaxObserver,
+        )
+        weight_qspec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
+        )
+        custom_qconfig = QuantizationConfig(
+            input_activation=activation_qspec,
+            output_activation=activation_qspec,
+            weight=weight_qspec,
+            bias=None,
+            is_qat=False,
+        )
+        quantizer = XNNPACKQuantizer()
+        quantizer.set_global(custom_qconfig)
+        m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
+        m = convert_pt2e(m)
+
+        # Export, then perform constant propagation to make the weights constant.
+        aten_prog = export(m, (query, key, value))
+        aten_prog = constant_prop_pass(aten_prog)
+
+        # Lower to the edge dialect.
+        edge_prog = to_edge(
+            aten_prog,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+        edge_prog = edge_prog.to_backend(XnnpackPartitioner())
+
+        # Perform constant propagation on the ops decomposed from sdpa.
+        aten_prog = constant_prop_pass(edge_prog.exported_program())
+        # There should be at least one constant due to the sdpa op.
+        self.assertGreaterEqual(len(aten_prog.constants), 1)
+
     def test_constant_prop_pass_for_parameter_slice(self) -> None:
         def count_slice(gm: torch.fx.GraphModule) -> int:
             return sum(