From a11afab70805e133c3274819ba44349c993e9809 Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Mon, 2 Dec 2024 09:14:38 -0800
Subject: [PATCH] Run decompositions before the quantizer (#7111)

Summary:
In the current flow, decompositions run in `to_edge()`, long after the
quantization process is done. This creates a lot of issues, since we cannot
quantize any operations contained in the large composite operators that the
graph tracer can emit (e.g. aten.scaled_dot_product_attention, aten.rnn_.input,
and a few others). Any model using those will see many fp32 operators in the
final graph.

Running the decompositions earlier solves the problem, but we need to retain
a couple of operators that the quantizer relies on, like `aten.linear`,
`aten.conv1d` and `aten.conv2d`.

Reviewed By: zonglinpeng

Differential Revision: D66461406
---
 backends/cadence/aot/compiler.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 937e3e39bc1..6b3a023181c 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -28,6 +28,7 @@
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -58,16 +59,33 @@ def convert_pt2(
 
     Returns a GraphModule with the converted model.
     """
+    # Get default decompositions
+    decomp_table = torch.export.default_decompositions()
+    # Select ops to keep
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # Remove decompositions for the ops we want to keep
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )
 
-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)
 
         # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
         # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
         assert result is not None
         model_gm = result.graph_module
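
For anyone who wants to exercise the new flow in isolation, a minimal sketch
of the pattern follows. The toy model, input shapes, and single-op keep list
are illustrative rather than taken from this patch, and it assumes a PyTorch
build where `torch.export.default_decompositions` is available:

```python
import torch
from torch._inductor.decomposition import remove_decompositions

# Toy model and example inputs, stand-ins for whatever convert_pt2() receives.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
inputs = (torch.randn(2, 8),)

# Start from the default decomposition table, then drop the entries for the
# ops the quantizer needs to see intact (here just aten.linear).
decomp_table = torch.export.default_decompositions()
remove_decompositions(decomp_table, [torch.ops.aten.linear.default])

# Decompositions now run at export time, before any quantization pass,
# so aten.linear survives as a single node in the exported graph.
model_gm = (
    torch.export.export_for_training(model, inputs)
    .run_decompositions(decomp_table)
    .module()
)
print(model_gm.graph)  # expect aten.linear.default, not its decomposition
```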