@@ -795,9 +795,19 @@ def _generate_edge_program(
795795 name : str ,
796796 config : EdgeCompileConfig ,
797797 program : ExportedProgram ,
798- ops_set_to_not_decompose : Optional [List [torch ._ops .OpOverload ]] = None ,
798+ core_aten_ops_exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
799+ preserve_ops : Optional [List [torch ._ops .OpOverload ]] = None ,
799800) -> ExportedProgram :
800-
801+ """
802+ Args:
803+ name: The name of the program.
804+ config: The configuration for the edge program.
805+ program: The exported program to be converted to an edge program.
806+ core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
807+ preserve_ops: A list of aten ops that should not be decomposed.
808+ Returns:
809+ An ExportedProgram in edge dialect.
810+ """
801811 # Remove invalid assert ops, such as _assert_tensor_metadata
802812 gm = program .graph_module
803813 gm_res = RemoveNonCoreAtenOpGraphAssertsPass ()(gm )
@@ -812,7 +822,8 @@ def _generate_edge_program(
812822 EXIRATenDialectVerifier (
813823 edge_compile_config = config ,
814824 class_only = False ,
815- exception_list = ops_set_to_not_decompose ,
825+ core_aten_ops_exception_list = core_aten_ops_exception_list ,
826+ preserve_ops = preserve_ops ,
816827 )(gm )
817828 except ExportError as e :
818829 logging .info (f"Input program { name } is not in ATen dialect." )
@@ -848,7 +859,8 @@ def _generate_edge_program(
848859 EXIREdgeDialectVerifier (
849860 edge_compile_config = config ,
850861 class_only = True ,
851- exception_list = ops_set_to_not_decompose ,
862+ core_aten_ops_exception_list = core_aten_ops_exception_list ,
863+ preserve_ops = preserve_ops ,
852864 )
853865 ],
854866 )
@@ -864,7 +876,7 @@ def _replace_aten_ops_with_transformed_ops(
864876 program : ExportedProgram ,
865877 partitioner ,
866878):
867- ops_to_not_decompose = set ()
879+ preserve_ops = set ()
868880 partitioners = partitioner .get (name )
869881 if partitioners is None :
870882 return
@@ -889,7 +901,7 @@ def _replace_aten_ops_with_transformed_ops(
889901 and node .target in ops_set_to_not_decompose
890902 and is_op_supported
891903 ):
892- ops_to_not_decompose .add (node .target )
904+ preserve_ops .add (node .target )
893905 node .target = aten_op_to_transform_op [node .target ]
894906
895907 for _ , submod , _ in get_control_flow_submodules (program .graph_module ):
@@ -900,10 +912,10 @@ def _replace_aten_ops_with_transformed_ops(
900912 and node .target in ops_set_to_not_decompose
901913 and is_op_supported
902914 ):
903- ops_to_not_decompose .add (node .target )
915+ preserve_ops .add (node .target )
904916 node .target = aten_op_to_transform_op [node .target ]
905917
906- return ops_to_not_decompose
918+ return preserve_ops
907919
908920
909921def _restore_transformed_ops_to_aten_ops (program : ExportedProgram ):
@@ -1014,7 +1026,7 @@ def _sanity_check_graph_for_non_decomp_ops(
10141026
10151027
10161028def _remove_invalid_ops_for_not_decompose (
1017- ops_to_not_decompose : List [torch ._ops .OpOverload ],
1029+ preserve_ops : List [torch ._ops .OpOverload ],
10181030) -> List [torch ._ops .OpOverload ]:
10191031 _logged_warnings = set ()
10201032
@@ -1079,7 +1091,7 @@ def keep(op):
10791091 return False
10801092 return True
10811093
1082- return list (filter (keep , ops_to_not_decompose ))
1094+ return list (filter (keep , preserve_ops ))
10831095
10841096
10851097def _gen_edge_manager_for_partitioners (
@@ -1136,7 +1148,7 @@ def _gen_edge_manager_for_partitioners(
11361148 name ,
11371149 config ,
11381150 program ,
1139- list (ops_set_to_not_decompose_by_program .get (name , [])),
1151+ preserve_ops = list (ops_set_to_not_decompose_by_program .get (name , [])),
11401152 )
11411153
11421154 edge_manager = EdgeProgramManager (
@@ -1281,7 +1293,7 @@ def to_edge_transform_and_lower(
12811293 EXIREdgeDialectVerifier (
12821294 edge_compile_config = config ,
12831295 class_only = True ,
1284- exception_list = list (ops_set_to_not_decompose ),
1296+ preserve_ops = list (ops_set_to_not_decompose ),
12851297 )()(program .graph_module )
12861298
12871299 return edge_manager
@@ -1328,7 +1340,7 @@ def to_edge_with_preserved_ops(
13281340 table .pop (op , None )
13291341 program = program .run_decompositions (table )
13301342 edge_programs [name ] = _generate_edge_program (
1331- name , config , program , list (preserve_ops )
1343+ name , config , program , preserve_ops = list (preserve_ops )
13321344 )
13331345
13341346 return EdgeProgramManager (
@@ -1367,8 +1379,16 @@ def to_edge(
13671379
13681380 for name , program in aten_programs .items ():
13691381 # Decompose to Core ATen
1370- program = program .run_decompositions (_default_decomposition_table ())
1371- edge_programs [name ] = _generate_edge_program (name , config , program )
1382+ table = _default_decomposition_table ()
1383+ preserve_ops = []
1384+ if compile_config :
1385+ preserve_ops = compile_config ._preserve_ops
1386+ for op in compile_config ._preserve_ops :
1387+ table .pop (op , None )
1388+ program = program .run_decompositions (table )
1389+ edge_programs [name ] = _generate_edge_program (
1390+ name , config , program , preserve_ops = preserve_ops
1391+ )
13721392
13731393 return EdgeProgramManager (edge_programs , constant_methods , config )
13741394
@@ -1389,7 +1409,8 @@ def __init__(
13891409 edge_programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
13901410 constant_methods : Optional [Dict [str , Any ]] = None ,
13911411 compile_config : Optional [EdgeCompileConfig ] = None ,
1392- ops_set_to_not_decompose : Optional [List [torch ._ops .OpOverload ]] = None ,
1412+ core_aten_ops_exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
1413+ preserve_ops : Optional [List [torch ._ops .OpOverload ]] = None ,
13931414 ):
13941415 """
13951416 Should not be called directly by users. User should use :func:'to_edge' instead.
@@ -1404,7 +1425,8 @@ def __init__(
14041425 try :
14051426 EXIREdgeDialectVerifier (
14061427 edge_compile_config = self .compile_config ,
1407- exception_list = ops_set_to_not_decompose ,
1428+ core_aten_ops_exception_list = core_aten_ops_exception_list ,
1429+ preserve_ops = preserve_ops ,
14081430 )(program .graph_module )
14091431 except ExportError as e :
14101432 logging .info (f"Input program { name } is not in aten dialect." )
0 commit comments