Skip to content

Commit b991271

Browse files
Fix op decomposition issue when multiple partitioners with conflicting expectations are run (#14458)
1 parent 95888a4 commit b991271

File tree

3 files changed

+64
-59
lines changed

3 files changed

+64
-59
lines changed

exir/program/_program.py

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import io
1212
import logging
1313
import os
14+
from collections import defaultdict
1415
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Type, Union
1516

1617
import torch
@@ -1136,25 +1137,16 @@ def keep(op):
11361137

11371138

11381139
def _can_skip_using_EDGE_DO_NOT_DECOMP(
1139-
partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
1140+
partitioner: Partitioner, program: ExportedProgram
11401141
) -> bool:
11411142
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
11421143
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
11431144
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
11441145
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
11451146
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
11461147
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1147-
can_skip_using_EDGE_DO_NOT_DECOMP = True
1148-
for name, program in aten_programs.items():
1149-
if partitioner is not None:
1150-
for curr_partitioner in partitioner.get(name, []):
1151-
(
1152-
curr_ops_no_decomp,
1153-
check_op_support,
1154-
) = curr_partitioner.ops_to_not_decompose(program)
1155-
if check_op_support is not None:
1156-
can_skip_using_EDGE_DO_NOT_DECOMP = False
1157-
return can_skip_using_EDGE_DO_NOT_DECOMP
1148+
_, check_op_support = partitioner.ops_to_not_decompose(program)
1149+
return check_op_support is None
11581150

11591151

11601152
def _gen_edge_manager_for_partitioners(
@@ -1177,60 +1169,75 @@ def _gen_edge_manager_for_partitioners(
11771169
on nodes with preserved aten targets. They are then replaces with transformed ops to
11781170
keep them through the second pass of decompositions
11791171
"""
1180-
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1181-
partitioner, aten_programs
1182-
)
1183-
ops_set_to_not_decompose_by_program = {}
1172+
ops_set_to_not_decompose_by_program = defaultdict(list)
11841173
edge_programs: Dict[str, ExportedProgram] = {}
11851174
for name, program in aten_programs.items():
11861175
# Functionalize program before asking partitioners to preserve ops
11871176
program = program.run_decompositions({})
11881177

11891178
if partitioner is not None:
1190-
# preserve all ops listed by all partitioners first
1191-
all_ops_no_decomp = set()
1192-
all_ops_no_decomp_needing_preservation = []
1193-
for curr_partitioner in partitioner.get(name, []):
1179+
partitioners_for_program = partitioner.get(name, [])
1180+
final_ops_to_preserve = set()
1181+
1182+
# Decompose by default if there are no partitioners for the method
1183+
if not partitioners_for_program:
1184+
program = program.run_decompositions(_default_decomposition_table())
1185+
1186+
# Process each partitioner individually using their specific requirements
1187+
for curr_partitioner in partitioners_for_program:
11941188
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1195-
all_ops_no_decomp |= set(curr_ops_no_decomp)
11961189

1197-
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1198-
# Otherwise there will be issues
1199-
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1200-
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1201-
list(all_ops_no_decomp)
1202-
)
1203-
all_ops_no_decomp = set(all_ops_no_decomp)
1204-
1205-
# Run default decompositions, except for those in all_ops_no_decomp
1206-
table = _default_decomposition_table()
1207-
for op in all_ops_no_decomp:
1208-
if table.pop(op, None) is not None:
1209-
all_ops_no_decomp_needing_preservation.append(op)
1210-
program = program.run_decompositions(table)
1211-
1212-
# Among all the preserved aten ops, use the check_op_fn to do an additional
1213-
# check on which ops need to be preserved and which ops need to be decomposed
1214-
# Those which are truly preserved will be replaced with transformed ops
1215-
if can_skip_using_EDGE_DO_NOT_DECOMP:
1216-
ops_set_to_not_decompose_by_program[name] = (
1217-
all_ops_no_decomp_needing_preservation
1218-
)
1219-
else:
1220-
ops_set_to_not_decompose_by_program[name] = (
1221-
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
1222-
or []
1190+
# Check if this partitioner can skip using EDGE_DO_NOT_DECOMP
1191+
can_skip_using_edge_do_not_decomp = _can_skip_using_EDGE_DO_NOT_DECOMP(
1192+
curr_partitioner, program
12231193
)
12241194

1225-
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1226-
program = program.run_decompositions(_default_decomposition_table())
1227-
_restore_transformed_ops_to_aten_ops(program)
1195+
if can_skip_using_edge_do_not_decomp:
1196+
# Preserve all ops in curr_ops_no_decomp from decomposition
1197+
table = _default_decomposition_table()
1198+
ops_needing_preservation = []
1199+
1200+
for op in curr_ops_no_decomp:
1201+
if table.pop(op, None) is not None:
1202+
ops_needing_preservation.append(op)
1203+
1204+
program = program.run_decompositions(table)
1205+
final_ops_to_preserve.update(ops_needing_preservation)
1206+
else:
1207+
# EDGE_DO_NOT_DECOMP path for the partitioner
1208+
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1209+
curr_ops_no_decomp
1210+
)
1211+
1212+
# Apply decompositions with this partitioner's preserved ops
1213+
table = _default_decomposition_table()
1214+
for op in curr_ops_no_decomp:
1215+
table.pop(op, None)
1216+
1217+
# First pass of decompositions with this partitioner's preserved ops
1218+
program = program.run_decompositions(table)
1219+
1220+
# Filter ops using EDGE_DO_NOT_DECOMP
1221+
temp_partitioner_dict = {name: [curr_partitioner]}
1222+
preserved_ops = (
1223+
_replace_aten_ops_with_transformed_ops(
1224+
name, program, temp_partitioner_dict
1225+
)
1226+
or []
1227+
)
1228+
final_ops_to_preserve.update(preserved_ops)
1229+
1230+
# Second pass of decompositions with this partitioner's preserved ops after filtering
1231+
program = program.run_decompositions(_default_decomposition_table())
1232+
1233+
# Restore ops from edge_no_decomp_namespace to aten ops
1234+
_restore_transformed_ops_to_aten_ops(program)
1235+
ops_set_to_not_decompose_by_program[name].extend(final_ops_to_preserve)
12281236

1229-
edge_programs[name] = program
12301237
edge_programs[name] = _generate_edge_program(
12311238
config,
12321239
program,
1233-
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
1240+
preserve_ops=ops_set_to_not_decompose_by_program.get(name, []),
12341241
)
12351242

12361243
edge_manager = EdgeProgramManager(
@@ -1349,9 +1356,6 @@ def to_edge_transform_and_lower( # noqa: C901
13491356
elif partitioner is None:
13501357
partitioner = {name: [] for name in aten_programs.keys()}
13511358

1352-
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1353-
partitioner, aten_programs
1354-
)
13551359
edge_manager = _gen_edge_manager_for_partitioners(
13561360
partitioner, aten_programs, config, constant_methods, generate_etrecord
13571361
)
@@ -1377,7 +1381,8 @@ def to_edge_transform_and_lower( # noqa: C901
13771381
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
13781382
program
13791383
)
1380-
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1384+
1385+
if not _can_skip_using_EDGE_DO_NOT_DECOMP(curr_partitioner, program):
13811386
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
13821387
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
13831388
_sanity_check_graph_for_non_decomp_ops(

export/target_recipes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def get_ios_recipe(
9393
# pyre-ignore
9494
"ios-arm64-coreml-fp32": [CoreMLRecipeType.FP32, XNNPackRecipeType.FP32],
9595
# pyre-ignore
96-
"ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16],
96+
"ios-arm64-coreml-fp16": [CoreMLRecipeType.FP16, XNNPackRecipeType.FP32],
9797
# pyre-ignore
9898
"ios-arm64-coreml-int8": [CoreMLRecipeType.PT2E_INT8_STATIC],
9999
}
@@ -165,7 +165,7 @@ def get_android_recipe(
165165

166166
android_configs: Dict[str, List[RecipeType]] = {
167167
# pyre-ignore
168-
"android-arm64-snapdragon-fp16": [QNNRecipeType.FP16],
168+
"android-arm64-snapdragon-fp16": [QNNRecipeType.FP16, XNNPackRecipeType.FP32],
169169
}
170170

171171
if target_config not in android_configs:

export/tests/test_target_recipes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _get_model_test_configs(
387387
@classmethod
388388
def _get_recipes(cls) -> Dict[str, Tuple[ExportRecipe, str]]:
389389
"""Get available recipes with their configurations based on platform."""
390-
all_recipes = {}
390+
all_recipes: Dict[str, Tuple[ExportRecipe, str]] = {}
391391

392392
# Add iOS recipes
393393
if is_supported_platform_for_coreml_lowering():

0 commit comments

Comments
 (0)