3131 EdgeProgramManager ,
3232 ExecutorchBackendConfig ,
3333 ExecutorchProgramManager ,
34- to_edge ,
3534)
3635from executorch .exir .pass_base import PassResult
3736from executorch .exir .passes import ToOutVarPass
3837from executorch .exir .passes .sym_shape_eval_pass import HintBasedSymShapeEvalPass
38+ from executorch .exir .program ._program import to_edge_with_preserved_ops
3939from torch ._inductor .decomposition import remove_decompositions
4040from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
4141
@@ -80,6 +80,7 @@ def convert_pt2(
8080 torch .ops .aten .layer_norm .default ,
8181 torch .ops .aten .linear .default ,
8282 torch .ops .aten .matmul .default ,
83+ torch .ops .aten .rms_norm .default ,
8384 ]
8485 # Remove decompositions for the ops we want to keep
8586 # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
@@ -201,9 +202,9 @@ def lower_ep_to_edge(
201202 """
202203 Lower an ExportedProgram to an EdgeProgramManager (in edge IR).
203204 """
204- # Call to_edge to convert the graph to edge IR.
205+ # Call to_edge_with_preserved_ops to convert the graph to edge IR.
205206 # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
206- edge_prog_manager = to_edge (
207+ edge_prog_manager = to_edge_with_preserved_ops (
207208 expo_program ,
208209 compile_config = EdgeCompileConfig (
209210 _skip_dim_order = True ,
@@ -216,9 +217,11 @@ def lower_ep_to_edge(
216217 torch .ops .aten .linalg_vector_norm .default ,
217218 torch .ops .aten .unfold .default ,
218219 torch .ops .aten .angle .default ,
220+ torch .ops .aten .rms_norm .default ,
219221 ],
220222 ),
221223 constant_methods = constant_methods ,
224+ preserve_ops = (torch .ops .aten .rms_norm .default ,),
222225 )
223226
224227 if dump_graphs :
0 commit comments