Skip to content

Commit b6da44d

Browse files
authored
Merge branch 'main' into shoumikhin-patch-3
2 parents 9ed0430 + 5e03d33 commit b6da44d

File tree

2 files changed

+49
-10
lines changed

2 files changed

+49
-10
lines changed

exir/program/_program.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,27 @@ def keep(op):
10941094
return list(filter(keep, preserve_ops))
10951095

10961096

1097+
def _can_skip_using_EDGE_DO_NOT_DECOMP(
1098+
partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
1099+
) -> bool:
1100+
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
1101+
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
1102+
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
1103+
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
1104+
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
1105+
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1106+
can_skip_using_EDGE_DO_NOT_DECOMP = True
1107+
for name, program in aten_programs.items():
1108+
if partitioner is not None:
1109+
for curr_partitioner in partitioner.get(name, []):
1110+
curr_ops_no_decomp, check_op_support = (
1111+
curr_partitioner.ops_to_not_decompose(program)
1112+
)
1113+
if check_op_support is not None:
1114+
can_skip_using_EDGE_DO_NOT_DECOMP = False
1115+
return can_skip_using_EDGE_DO_NOT_DECOMP
1116+
1117+
10971118
def _gen_edge_manager_for_partitioners(
10981119
partitioner: Dict[str, List[Partitioner]],
10991120
aten_programs: Dict[str, ExportedProgram],
@@ -1113,37 +1134,54 @@ def _gen_edge_manager_for_partitioners(
11131134
on nodes with preserved aten targets. They are then replaces with transformed ops to
11141135
keep them through the second pass of decompositions
11151136
"""
1137+
1138+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1139+
partitioner, aten_programs
1140+
)
1141+
11161142
ops_set_to_not_decompose_by_program = {}
11171143
edge_programs: Dict[str, ExportedProgram] = {}
11181144
for name, program in aten_programs.items():
11191145
if partitioner is not None:
11201146
# preserve all ops listed by all partitioners first
11211147
all_ops_no_decomp = set()
1148+
1149+
# This holds the subset of all_ops_no_decomp that actually need preservation, i.e.,
1150+
# the ones where the decomposition table has an entry for the op
1151+
all_ops_no_decomp_needing_preservation = []
11221152
for curr_partitioner in partitioner.get(name, []):
11231153
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1124-
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1125-
curr_ops_no_decomp
1126-
)
1154+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1155+
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1156+
curr_ops_no_decomp
1157+
)
11271158
all_ops_no_decomp |= set(curr_ops_no_decomp)
11281159

11291160
table = _default_decomposition_table()
11301161

11311162
for op in all_ops_no_decomp:
1132-
table.pop(op, None)
1163+
if table.pop(op, None) is not None:
1164+
all_ops_no_decomp_needing_preservation.append(op)
11331165

11341166
program = program.run_decompositions(table)
11351167
# Among all the preserved aten ops, use the check_op_fn to do an additional
11361168
# check on which ops need to be preserved and which ops need to be decomposed
11371169
# Those which are truly preserved will be replaced with transformed ops
1138-
ops_set_to_not_decompose_by_program[name] = (
1139-
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
1140-
)
1141-
program = program.run_decompositions(_default_decomposition_table())
1170+
if can_skip_using_EDGE_DO_NOT_DECOMP:
1171+
ops_set_to_not_decompose_by_program[name] = (
1172+
all_ops_no_decomp_needing_preservation
1173+
)
1174+
else:
1175+
ops_set_to_not_decompose_by_program[name] = (
1176+
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
1177+
or []
1178+
)
11421179

1143-
_restore_transformed_ops_to_aten_ops(program)
1180+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1181+
program = program.run_decompositions(_default_decomposition_table())
1182+
_restore_transformed_ops_to_aten_ops(program)
11441183

11451184
edge_programs[name] = program
1146-
11471185
edge_programs[name] = _generate_edge_program(
11481186
name,
11491187
config,

scripts/build_apple_frameworks.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ libmpsdelegate.a,\
5252

5353
FRAMEWORK_BACKEND_XNNPACK="backend_xnnpack:\
5454
libXNNPACK.a,\
55+
libkleidiai.a,\
5556
libxnnpack_backend.a,\
5657
libxnnpack-microkernels-prod.a,\
5758
:"

0 commit comments

Comments
 (0)