1111from executorch .backends .arm ._passes .annotate_channels_last_dim_order_pass import (
1212 AnnotateChannelsLastDimOrder ,
1313)
14+ from executorch .backends .arm ._passes .annotate_decomposed_matmul import (
15+ AnnotateDecomposedMatmulPass ,
16+ )
1417from executorch .backends .arm ._passes .cast_int64_pass import CastInt64ToInt32Pass
1518from executorch .backends .arm ._passes .conv1d_unsqueeze_pass import Conv1dUnsqueezePass
1619from executorch .backends .arm ._passes .convert_expand_copy_to_repeat import (
3235from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
3336 FoldAndAnnotateQParamsPass ,
3437 QuantizeFullArgument ,
38+ RetraceFoldedDtypesPass ,
3539)
40+ from executorch .backends .arm ._passes .insert_table_ops import InsertTableOpsPass
3641from executorch .backends .arm ._passes .keep_dims_false_to_squeeze_pass import (
3742 KeepDimsFalseToSqueezePass ,
3843)
@@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
6772 self , exported_program : ExportedProgram , compile_spec : list [CompileSpec ]
6873 ):
6974 """Apply passes before transforming program to backend"""
70- self .add_pass (CastInt64ToInt32Pass ( exported_program ))
75+ self .add_pass (DecomposeLinearPass ( ))
7176 self .add_pass (RemoveGetItemPass ())
72- self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
73- self .add_pass (SizeAdjustConv2DPass ())
74- self .add_pass (RemoveClonePass ())
75- self .add_pass (ConvertExpandCopyToRepeatPass ())
7677 self .add_pass (DecomposeLayerNormPass ())
77- self .add_pass (UnsqueezeBeforeRepeatPass ())
7878 self .add_pass (DecomposeVarPass ())
7979 self .add_pass (ConvertMeanDimToAveragePool ())
8080 self .add_pass (DecomposeMeanDimPass ())
81- self .add_pass (MatchArgRanksPass (exported_program ))
82- self .add_pass (DecomposeDivPass ())
83- self .add_pass (KeepDimsFalseToSqueezePass ())
8481 self .add_pass (ConvertSplitToSlicePass ())
85- self .add_pass (Conv1dUnsqueezePass (exported_program ))
86- self .add_pass (DecomposeSoftmaxesPass ())
87- self .add_pass (DecomposeLinearPass ())
82+ # TODO MLETORCH-558
83+ self .add_pass (AnnotateDecomposedMatmulPass ())
8884 self .add_pass (QuantizeFullArgument ())
8985 self .add_pass (
9086 FoldAndAnnotateQParamsPass (
@@ -93,11 +89,49 @@ def transform_to_backend_pipeline(
9389 exir_ops .edge .aten .maximum .default ,
9490 exir_ops .edge .aten .add .Tensor ,
9591 exir_ops .edge .aten .avg_pool2d .default ,
92+ exir_ops .edge .aten .bmm .default ,
93+ exir_ops .edge .aten .cat .default ,
9694 exir_ops .edge .aten .convolution .default ,
95+ exir_ops .edge .aten .clone .default ,
96+ exir_ops .edge .aten .exp .default ,
97+ exir_ops .edge .aten .expand_copy .default ,
9798 exir_ops .edge .aten .full .default ,
99+ exir_ops .edge .aten .hardtanh .default ,
100+ exir_ops .edge .aten .log .default ,
101+ exir_ops .edge .aten .max_pool2d .default ,
102+ exir_ops .edge .aten .mm .default ,
103+ exir_ops .edge .aten .mul .Tensor ,
104+ exir_ops .edge .aten .permute_copy .default ,
105+ exir_ops .edge .aten .reciprocal .default ,
106+ exir_ops .edge .aten .relu .default ,
107+ exir_ops .edge .aten .repeat .default ,
108+ exir_ops .edge .aten .rsqrt .default ,
109+ exir_ops .edge .aten .select_copy .int ,
110+ exir_ops .edge .aten .sigmoid .default ,
111+ exir_ops .edge .aten .slice_copy .Tensor ,
112+ exir_ops .edge .aten .squeeze_copy .dims ,
113+ exir_ops .edge .aten .sub .Tensor ,
114+ exir_ops .edge .aten .sum .dim_IntList ,
115+ exir_ops .edge .aten .tanh .default ,
116+ exir_ops .edge .aten .unsqueeze_copy .default ,
117+ exir_ops .edge .aten .upsample_nearest2d .vec ,
118+ exir_ops .edge .aten .view_copy .default ,
98119 ]
99120 )
100121 )
122+ self .add_pass (RetraceFoldedDtypesPass ())
123+ self .add_pass (InsertTableOpsPass (exported_program ))
124+ self .add_pass (ConvertExpandCopyToRepeatPass ())
125+ self .add_pass (UnsqueezeBeforeRepeatPass ())
126+ self .add_pass (CastInt64ToInt32Pass (exported_program ))
127+ self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
128+ self .add_pass (SizeAdjustConv2DPass ())
129+ self .add_pass (RemoveClonePass ())
130+ self .add_pass (MatchArgRanksPass (exported_program ))
131+ self .add_pass (DecomposeDivPass ())
132+ self .add_pass (KeepDimsFalseToSqueezePass ())
133+ self .add_pass (Conv1dUnsqueezePass (exported_program ))
134+ self .add_pass (DecomposeSoftmaxesPass ())
101135 for spec in compile_spec :
102136 if spec .key == "permute_memory_format" :
103137 memory_format = spec .value .decode ()
0 commit comments