11
11
import io
12
12
import logging
13
13
import os
14
+ from collections import defaultdict
14
15
from typing import Any , Dict , List , Optional , Sequence , Set , TextIO , Type , Union
15
16
16
17
import torch
@@ -1136,25 +1137,16 @@ def keep(op):
1136
1137
1137
1138
1138
1139
def _can_skip_using_EDGE_DO_NOT_DECOMP (
1139
- partitioner : Dict [ str , List [ Partitioner ]], aten_programs : Dict [ str , ExportedProgram ]
1140
+ partitioner : Partitioner , program : ExportedProgram
1140
1141
) -> bool :
1141
1142
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
1142
1143
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
1143
1144
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
1144
1145
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
1145
1146
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
1146
1147
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1147
- can_skip_using_EDGE_DO_NOT_DECOMP = True
1148
- for name , program in aten_programs .items ():
1149
- if partitioner is not None :
1150
- for curr_partitioner in partitioner .get (name , []):
1151
- (
1152
- curr_ops_no_decomp ,
1153
- check_op_support ,
1154
- ) = curr_partitioner .ops_to_not_decompose (program )
1155
- if check_op_support is not None :
1156
- can_skip_using_EDGE_DO_NOT_DECOMP = False
1157
- return can_skip_using_EDGE_DO_NOT_DECOMP
1148
+ _ , check_op_support = partitioner .ops_to_not_decompose (program )
1149
+ return check_op_support is None
1158
1150
1159
1151
1160
1152
def _gen_edge_manager_for_partitioners (
@@ -1177,60 +1169,75 @@ def _gen_edge_manager_for_partitioners(
1177
1169
on nodes with preserved aten targets. They are then replaces with transformed ops to
1178
1170
keep them through the second pass of decompositions
1179
1171
"""
1180
- can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1181
- partitioner , aten_programs
1182
- )
1183
- ops_set_to_not_decompose_by_program = {}
1172
+ ops_set_to_not_decompose_by_program = defaultdict (list )
1184
1173
edge_programs : Dict [str , ExportedProgram ] = {}
1185
1174
for name , program in aten_programs .items ():
1186
1175
# Functionalize program before asking partitioners to preserve ops
1187
1176
program = program .run_decompositions ({})
1188
1177
1189
1178
if partitioner is not None :
1190
- # preserve all ops listed by all partitioners first
1191
- all_ops_no_decomp = set ()
1192
- all_ops_no_decomp_needing_preservation = []
1193
- for curr_partitioner in partitioner .get (name , []):
1179
+ partitioners_for_program = partitioner .get (name , [])
1180
+ final_ops_to_preserve = set ()
1181
+
1182
+ # Decompose by default if there are no partitioners for the method
1183
+ if not partitioners_for_program :
1184
+ program = program .run_decompositions (_default_decomposition_table ())
1185
+
1186
+ # Process each partitioner individually using their specific requirements
1187
+ for curr_partitioner in partitioners_for_program :
1194
1188
curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1195
- all_ops_no_decomp |= set (curr_ops_no_decomp )
1196
1189
1197
- # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1198
- # Otherwise there will be issues
1199
- if not can_skip_using_EDGE_DO_NOT_DECOMP :
1200
- all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1201
- list (all_ops_no_decomp )
1202
- )
1203
- all_ops_no_decomp = set (all_ops_no_decomp )
1204
-
1205
- # Run default decompositions, except for those in all_ops_no_decomp
1206
- table = _default_decomposition_table ()
1207
- for op in all_ops_no_decomp :
1208
- if table .pop (op , None ) is not None :
1209
- all_ops_no_decomp_needing_preservation .append (op )
1210
- program = program .run_decompositions (table )
1211
-
1212
- # Among all the preserved aten ops, use the check_op_fn to do an additional
1213
- # check on which ops need to be preserved and which ops need to be decomposed
1214
- # Those which are truly preserved will be replaced with transformed ops
1215
- if can_skip_using_EDGE_DO_NOT_DECOMP :
1216
- ops_set_to_not_decompose_by_program [name ] = (
1217
- all_ops_no_decomp_needing_preservation
1218
- )
1219
- else :
1220
- ops_set_to_not_decompose_by_program [name ] = (
1221
- _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1222
- or []
1190
+ # Check if this partitioner can skip using EDGE_DO_NOT_DECOMP
1191
+ can_skip_using_edge_do_not_decomp = _can_skip_using_EDGE_DO_NOT_DECOMP (
1192
+ curr_partitioner , program
1223
1193
)
1224
1194
1225
- if not can_skip_using_EDGE_DO_NOT_DECOMP :
1226
- program = program .run_decompositions (_default_decomposition_table ())
1227
- _restore_transformed_ops_to_aten_ops (program )
1195
+ if can_skip_using_edge_do_not_decomp :
1196
+ # Preserve all ops in curr_ops_no_decomp from decomposition
1197
+ table = _default_decomposition_table ()
1198
+ ops_needing_preservation = []
1199
+
1200
+ for op in curr_ops_no_decomp :
1201
+ if table .pop (op , None ) is not None :
1202
+ ops_needing_preservation .append (op )
1203
+
1204
+ program = program .run_decompositions (table )
1205
+ final_ops_to_preserve .update (ops_needing_preservation )
1206
+ else :
1207
+ # EDGE_DO_NOT_DECOMP path for the partitioner
1208
+ curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1209
+ curr_ops_no_decomp
1210
+ )
1211
+
1212
+ # Apply decompositions with this partitioner's preserved ops
1213
+ table = _default_decomposition_table ()
1214
+ for op in curr_ops_no_decomp :
1215
+ table .pop (op , None )
1216
+
1217
+ # First pass of decompositions with this partitioner's preserved ops
1218
+ program = program .run_decompositions (table )
1219
+
1220
+ # Filter ops using EDGE_DO_NOT_DECOMP
1221
+ temp_partitioner_dict = {name : [curr_partitioner ]}
1222
+ preserved_ops = (
1223
+ _replace_aten_ops_with_transformed_ops (
1224
+ name , program , temp_partitioner_dict
1225
+ )
1226
+ or []
1227
+ )
1228
+ final_ops_to_preserve .update (preserved_ops )
1229
+
1230
+ # Second pass of decompositions with this partitioner's preserved ops after filtering
1231
+ program = program .run_decompositions (_default_decomposition_table ())
1232
+
1233
+ # Restore ops from edge_no_decomp_namespace to aten ops
1234
+ _restore_transformed_ops_to_aten_ops (program )
1235
+ ops_set_to_not_decompose_by_program [name ].extend (final_ops_to_preserve )
1228
1236
1229
- edge_programs [name ] = program
1230
1237
edge_programs [name ] = _generate_edge_program (
1231
1238
config ,
1232
1239
program ,
1233
- preserve_ops = list ( ops_set_to_not_decompose_by_program .get (name , []) ),
1240
+ preserve_ops = ops_set_to_not_decompose_by_program .get (name , []),
1234
1241
)
1235
1242
1236
1243
edge_manager = EdgeProgramManager (
@@ -1349,9 +1356,6 @@ def to_edge_transform_and_lower( # noqa: C901
1349
1356
elif partitioner is None :
1350
1357
partitioner = {name : [] for name in aten_programs .keys ()}
1351
1358
1352
- can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1353
- partitioner , aten_programs
1354
- )
1355
1359
edge_manager = _gen_edge_manager_for_partitioners (
1356
1360
partitioner , aten_programs , config , constant_methods , generate_etrecord
1357
1361
)
@@ -1377,7 +1381,8 @@ def to_edge_transform_and_lower( # noqa: C901
1377
1381
curr_op_set , check_op_support = curr_partitioner .ops_to_not_decompose (
1378
1382
program
1379
1383
)
1380
- if not can_skip_using_EDGE_DO_NOT_DECOMP :
1384
+
1385
+ if not _can_skip_using_EDGE_DO_NOT_DECOMP (curr_partitioner , program ):
1381
1386
curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
1382
1387
ops_set_to_not_decompose = ops_set_to_not_decompose .union (curr_op_set )
1383
1388
_sanity_check_graph_for_non_decomp_ops (
0 commit comments