     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -58,16 +59,33 @@ def convert_pt2(
     Returns a GraphModule with the converted model.
     """
 
+    # Get default decompositions
+    decomp_table = torch.export.default_decompositions()
+    # Select ops to keep
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # Remove decompositions for the ops we want to keep
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )
 
-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)
 
         # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
         # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
         assert result is not None
         model_gm = result.graph_module
 
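For reference, the core pattern this change applies is: start from the default export decomposition table, prune the entries for ops that should survive as single ATen nodes, then pass the pruned table to `run_decompositions()`. Below is a minimal standalone sketch of that pattern outside the Cadence flow; the toy `TinyModel` module and the final graph check are illustrative assumptions, not part of this diff.

```python
import torch
from torch._inductor.decomposition import remove_decompositions


class TinyModel(torch.nn.Module):
    """Toy module used only to illustrate the export flow (not part of this PR)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel().eval()
inputs = (torch.randn(2, 8),)

# Build the default decomposition table, then drop the entries for the ops
# we want to keep so run_decompositions() leaves them in the graph.
decomp_table = torch.export.default_decompositions()
ops_to_keep = [torch.ops.aten.linear.default]
remove_decompositions(decomp_table, ops_to_keep)

# Export with dynamo, then decompose everything except the preserved ops.
exported_program = torch.export.export_for_training(model, inputs)
model_gm = exported_program.run_decompositions(decomp_table).module()

# aten.linear should still show up as a single call_function node instead of
# being decomposed into mm/addmm + add.
targets = {
    node.target
    for node in model_gm.graph.nodes
    if node.op == "call_function"
}
print(torch.ops.aten.linear.default in targets)
```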