
Commit c09bdb3

pssrawat authored and facebook-github-bot committed
Fix name of const tensors in const_prop_pass.py (#7301)
Summary: After XNNPACK delegation, for the llama predictor, some const tensors in the exported program are eliminated, e.g., c__prop_tensor_constant32 in P1688172607. If constant_prop_pass is run after this point, the const-tensor naming logic in the pass is incorrect and can redefine existing const tensors (error in P1687944413). This diff fixes the const-tensor naming logic in the pass. Differential Revision: D66568874
1 parent 3f7eb3b commit c09bdb3
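
To make the failure mode concrete, here is a minimal repro sketch, plain Python with no ExecuTorch imports; the dict and tensor names below are hypothetical stand-ins for exported_program.constants after delegation has eliminated some previously propagated constants:

# Minimal sketch of the naming collision described in the summary.
# The dict stands in for exported_program.constants; the surviving
# entries are hypothetical.
import torch

constants = {
    "_prop_tensor_constant0": torch.zeros(1),
    "_prop_tensor_constant2": torch.zeros(1),  # "_prop_tensor_constant1" was eliminated
}

# Old naming scheme: derive the new name from the current number of constants.
new_name = f"_prop_tensor_constant{len(constants)}"
print(new_name)               # _prop_tensor_constant2
print(new_name in constants)  # True -> the existing constant would be redefined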

File tree: 3 files changed, +94 -1 lines changed


exir/passes/constant_prop_pass.py

Lines changed: 16 additions & 1 deletion
@@ -170,7 +170,22 @@ def replace_with_constant_node(
     exported_program: ExportedProgram,
 ) -> tuple[torch.fx.Node, str]:
     # Add `prop_constant_tensor` to program.state_dict.
-    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
+    prefix = "_prop_tensor_constant"
+    prop_constant_tensor_fqn = f"{prefix}{len(exported_program.constants)}"
+    # If prop_constant_tensor_fqn already exists in the state dict, we need
+    # to create a new name. Find the largest suffix of "_prop_tensor_constant",
+    # and increment it by 1 to form the new name.
+    if prop_constant_tensor_fqn in exported_program.constants:
+        suffix = 1 + max(
+            (
+                int(name[len(prefix) :])
+                for name in exported_program.constants.keys()
+                if name.startswith(prefix) and name[len(prefix) :].isdigit()
+            ),
+            default=-1,
+        )
+        prop_constant_tensor_fqn = f"{prefix}{suffix}"
+
     exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
 
     # Insert a new placeholder node for the propagated constant tensor.
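
For illustration, the same name-selection logic can be exercised in isolation. The helper below is a standalone sketch (next_const_name is a hypothetical name, not part of the pass) applied to the collision case shown earlier:

# Fall back to 1 + the largest existing "_prop_tensor_constant" suffix
# only when the len()-based name is already taken.
def next_const_name(constants: dict) -> str:
    prefix = "_prop_tensor_constant"
    fqn = f"{prefix}{len(constants)}"
    if fqn in constants:
        suffix = 1 + max(
            (
                int(name[len(prefix):])
                for name in constants
                if name.startswith(prefix) and name[len(prefix):].isdigit()
            ),
            default=-1,
        )
        fqn = f"{prefix}{suffix}"
    return fqn

print(next_const_name({"_prop_tensor_constant0": 0, "_prop_tensor_constant2": 0}))
# -> _prop_tensor_constant3 (no clash with the surviving constant 2)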

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -221,6 +221,7 @@ python_unittest(
         "//executorch/exir/passes:sym_to_tensor_pass",
         "//executorch/exir/program:program",
         "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
     ],
 )

exir/tests/test_passes.py

Lines changed: 77 additions & 0 deletions
@@ -16,6 +16,7 @@
 # Import passes
 import executorch.exir.memory_planning  # noqa
 import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -65,10 +66,12 @@
 from torch import nn
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -1238,6 +1241,80 @@ def forward(self, x):
             ],
         )
 
+    def test_constant_prop_pass_after_delegation(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self, dim=32):
+                super().__init__()
+                self.linear = torch.nn.Linear(dim, dim)
+
+            def forward(self, query, key, value):
+                query = self.linear(query)
+                key = self.linear(key)
+                value = self.linear(value)
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
+                )
+
+        query = torch.randn(32, 32, 32, 32)
+        key = torch.randn(32, 32, 32, 32)
+        value = torch.randn(32, 32, 32, 32)
+
+        # Capture the model
+        m = torch.export.export_for_training(M(32), (query, key, value)).module()
+
+        # 8w16a quantization
+        from torch.ao.quantization.observer import (
+            MinMaxObserver,
+            PerChannelMinMaxObserver,
+        )
+
+        activation_qspec = QuantizationSpec(
+            dtype=torch.int16,
+            quant_min=-32768,
+            quant_max=32767,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=MinMaxObserver,
+        )
+        weight_qspec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
+        )
+        custom_qconfig = QuantizationConfig(
+            input_activation=activation_qspec,
+            output_activation=activation_qspec,
+            weight=weight_qspec,
+            bias=None,
+            is_qat=False,
+        )
+        quantizer = XNNPACKQuantizer()
+        quantizer.set_global(custom_qconfig)
+        m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
+        m = convert_pt2e(m)
+
+        # Export, then perform constant propagation to make weights const
+        aten_prog = export(m, (query, key, value))
+        aten_prog = constant_prop_pass(aten_prog)
+
+        # Lower to edge dialect
+        edge_prog = to_edge(
+            aten_prog,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+        edge_prog = edge_prog.to_backend(XnnpackPartitioner())
+
+        # Perform constant propagation on the ops decomposed from sdpa
+        aten_prog = constant_prop_pass(edge_prog.exported_program())
+        # There should be at least one const due to the sdpa op
+        self.assertGreaterEqual(len(aten_prog.constants), 1)
+
     def test_constant_prop_pass_for_parameter_slice(self) -> None:
         def count_slice(gm: torch.fx.GraphModule) -> int:
             return sum(
