Skip to content

Commit 8c65c95

Browse files
committed
updates
1 parent 0db8bdd commit 8c65c95

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

exir/program/_program.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,7 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
11131113
can_skip_using_EDGE_DO_NOT_DECOMP = False
11141114
return can_skip_using_EDGE_DO_NOT_DECOMP
11151115

1116+
11161117
def _gen_edge_manager_for_partitioners(
11171118
partitioner: Dict[str, List[Partitioner]],
11181119
aten_programs: Dict[str, ExportedProgram],
@@ -1135,22 +1136,43 @@ def _gen_edge_manager_for_partitioners(
11351136
ops_set_to_not_decompose_by_program = {}
11361137
edge_programs: Dict[str, ExportedProgram] = {}
11371138
for name, program in aten_programs.items():
1139+
# Functionalize program without doing any decompositions
1140+
program = program.run_decompositions({})
1141+
ReplaceViewOpsWithViewCopyOpsPass()(program.graph_module)
1142+
1143+
print(program)
1144+
11381145
if partitioner is not None:
11391146
# preserve all ops listed by all partitioners first
11401147
all_ops_no_decomp = set()
11411148
for curr_partitioner in partitioner.get(name, []):
11421149
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1150+
<<<<<<< HEAD
11431151
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
11441152
curr_ops_no_decomp
11451153
)
1154+
=======
1155+
>>>>>>> ec44f8478 (updates)
11461156
all_ops_no_decomp |= set(curr_ops_no_decomp)
1147-
1157+
1158+
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1159+
# Otherwise there will be issues
1160+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1161+
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(list(all_ops_no_decomp))
1162+
all_ops_no_decomp = set(all_ops_no_decomp)
1163+
1164+
# Run default decompositions, except for those in all_ops_no_decomp
11481165
table = _default_decomposition_table()
1149-
11501166
for op in all_ops_no_decomp:
1167+
<<<<<<< HEAD
11511168
table.pop(op, None)
11521169

1170+
=======
1171+
if table.pop(op, None) is not None:
1172+
all_ops_no_decomp_needing_preservation.append(op)
1173+
>>>>>>> ec44f8478 (updates)
11531174
program = program.run_decompositions(table)
1175+
11541176
# Among all the preserved aten ops, use the check_op_fn to do an additional
11551177
# check on which ops need to be preserved and which ops need to be decomposed
11561178
# Those which are truly preserved will be replaced with transformed ops

0 commit comments

Comments
 (0)