 # Import passes
 import executorch.exir.memory_planning  # noqa
 import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload

 from torch import nn

 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter

@@ -1238,6 +1241,80 @@ def forward(self, x):
             ],
         )

+    def test_constant_prop_pass_after_delegation(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self, dim=32):
+                super().__init__()
+                self.linear = torch.nn.Linear(dim, dim)
+
+            def forward(self, query, key, value):
+                query = self.linear(query)
+                key = self.linear(key)
+                value = self.linear(value)
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
+                )
+
+        query = torch.randn(32, 32, 32, 32)
+        key = torch.randn(32, 32, 32, 32)
+        value = torch.randn(32, 32, 32, 32)
+
+        # Capture the model
+        m = torch.export.export_for_training(M(32), (query, key, value)).module()
+
+        # 8w16a quantization
+        from torch.ao.quantization.observer import (
+            MinMaxObserver,
+            PerChannelMinMaxObserver,
+        )
+
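+        # Activations: 16-bit, per-tensor affine quantization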
+        activation_qspec = QuantizationSpec(
+            dtype=torch.int16,
+            quant_min=-32768,
+            quant_max=32767,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=MinMaxObserver,
+        )
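+        # Weights: 8-bit, symmetric per-channel quantization (ch_axis=0)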
+        weight_qspec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
+        )
+        custom_qconfig = QuantizationConfig(
+            input_activation=activation_qspec,
+            output_activation=activation_qspec,
+            weight=weight_qspec,
+            bias=None,
+            is_qat=False,
+        )
+        quantizer = XNNPACKQuantizer()
+        quantizer.set_global(custom_qconfig)
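+        # Insert observers, then convert to the quantized model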
+        m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
+        m = convert_pt2e(m)
+
+        # Export, then run constant propagation to make the weights constant
+        aten_prog = export(m, (query, key, value))
+        aten_prog = constant_prop_pass(aten_prog)
+
+        # Lower to the edge dialect
+        edge_prog = to_edge(
+            aten_prog,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
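+        # Delegate XNNPACK-supported subgraphs to the XNNPACK backend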
+        edge_prog = edge_prog.to_backend(XnnpackPartitioner())
+
+        # Perform constant propagation on the decomposed ops from sdpa
+        aten_prog = constant_prop_pass(edge_prog.exported_program())
+        # There should be at least one constant due to the sdpa op
+        self.assertGreaterEqual(len(aten_prog.constants), 1)
+
     def test_constant_prop_pass_for_parameter_slice(self) -> None:
         def count_slice(gm: torch.fx.GraphModule) -> int:
             return sum(