
Commit 745d47a

Fix name of const tensors in const_prop_pass.py

Differential Revision: D66568874
Pull Request resolved: #7301

1 parent 4d2acf2 · commit 745d47a

File tree: 3 files changed, +94 −1 lines changed

exir/passes/constant_prop_pass.py
Lines changed: 16 additions & 1 deletion
@@ -170,7 +170,22 @@ def replace_with_constant_node(
     exported_program: ExportedProgram,
 ) -> tuple[torch.fx.Node, str]:
     # Add `prop_constant_tensor` to program.state_dict.
-    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
+    prefix = "_prop_tensor_constant"
+    prop_constant_tensor_fqn = f"{prefix}{len(exported_program.constants)}"
+    # If prop_constant_tensor_fqn already exists in exported_program.constants,
+    # we need to create a new name. Find the largest numeric suffix among the
+    # existing "_prop_tensor_constant*" names and increment it by 1.
+    if prop_constant_tensor_fqn in exported_program.constants:
+        suffix = 1 + max(
+            (
+                int(name[len(prefix) :])
+                for name in exported_program.constants.keys()
+                if name.startswith(prefix) and name[len(prefix) :].isdigit()
+            ),
+            default=-1,
+        )
+        prop_constant_tensor_fqn = f"{prefix}{suffix}"

     exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor

     # Insert a new placeholder node for the propagated constant tensor.
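
For reference, the new naming scheme in isolation: a minimal standalone sketch, assuming a plain dict in place of exported_program.constants. The helper name next_const_name is hypothetical, not part of the ExecuTorch API.

def next_const_name(constants: dict, prefix: str = "_prop_tensor_constant") -> str:
    # Fast path: the count-based name is still free.
    candidate = f"{prefix}{len(constants)}"
    if candidate not in constants:
        return candidate
    # Collision: find the largest numeric suffix among existing names
    # and go one past it, as the pass now does.
    suffix = 1 + max(
        (
            int(name[len(prefix):])
            for name in constants
            if name.startswith(prefix) and name[len(prefix):].isdigit()
        ),
        default=-1,
    )
    return f"{prefix}{suffix}"

# Example: two constants survive with non-contiguous suffixes, so
# len(constants) == 2 would collide with the existing "_prop_tensor_constant2".
existing = {"_prop_tensor_constant2": None, "_prop_tensor_constant5": None}
assert next_const_name(existing) == "_prop_tensor_constant6"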

exir/tests/TARGETS
Lines changed: 1 addition & 0 deletions
@@ -221,6 +221,7 @@ python_unittest(
         "//executorch/exir/passes:sym_to_tensor_pass",
         "//executorch/exir/program:program",
         "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
     ],
 )

exir/tests/test_passes.py
Lines changed: 77 additions & 0 deletions
@@ -16,6 +16,7 @@
 # Import passes
 import executorch.exir.memory_planning  # noqa
 import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -65,10 +66,12 @@
 from torch import nn

 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -1238,6 +1241,80 @@ def forward(self, x):
             ],
         )

+    def test_constant_prop_pass_after_delegation(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self, dim=32):
+                super().__init__()
+                self.linear = torch.nn.Linear(dim, dim)
+
+            def forward(self, query, key, value):
+                query = self.linear(query)
+                key = self.linear(key)
+                value = self.linear(value)
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
+                )
+
+        query = torch.randn(32, 32, 32, 32)
+        key = torch.randn(32, 32, 32, 32)
+        value = torch.randn(32, 32, 32, 32)
+
+        # Capture the model
+        m = torch.export.export_for_training(M(32), (query, key, value)).module()
+
+        # 8w16a quantization
+        from torch.ao.quantization.observer import (
+            MinMaxObserver,
+            PerChannelMinMaxObserver,
+        )
+
+        activation_qspec = QuantizationSpec(
+            dtype=torch.int16,
+            quant_min=-32768,
+            quant_max=32767,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=MinMaxObserver,
+        )
+        weight_qspec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
+        )
+        custom_qconfig = QuantizationConfig(
+            input_activation=activation_qspec,
+            output_activation=activation_qspec,
+            weight=weight_qspec,
+            bias=None,
+            is_qat=False,
+        )
+        quantizer = XNNPACKQuantizer()
+        quantizer.set_global(custom_qconfig)
+        m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
+        m = convert_pt2e(m)
+
+        # Export and run constant propagation to make the weights constant
+        aten_prog = export(m, (query, key, value))
+        aten_prog = constant_prop_pass(aten_prog)
+
+        # Lower to the edge dialect
+        edge_prog = to_edge(
+            aten_prog,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+        edge_prog = edge_prog.to_backend(XnnpackPartitioner())
+
+        # Run constant propagation again on the ops decomposed from sdpa
+        aten_prog = constant_prop_pass(edge_prog.exported_program())
+        # There should be at least one constant due to the sdpa op
+        self.assertGreaterEqual(len(aten_prog.constants), 1)
+
     def test_constant_prop_pass_for_parameter_slice(self) -> None:
         def count_slice(gm: torch.fx.GraphModule) -> int:
             return sum(
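
The scenario the test guards against can be sketched without any model at all. A hypothetical repro, again assuming a plain dict in place of exported_program.constants:

constants = {}
prefix = "_prop_tensor_constant"

# First constant_prop_pass run registers two tensors: ...0 and ...1.
for _ in range(2):
    constants[f"{prefix}{len(constants)}"] = object()

# Delegation (e.g. to_backend with XnnpackPartitioner) consumes one of them.
del constants[f"{prefix}0"]

# Second run: the old count-based scheme would derive a name that collides
# with the surviving "_prop_tensor_constant1"; the suffix scan above avoids this.
assert f"{prefix}{len(constants)}" in constants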
