Skip to content

Commit d31ede0

Browse files
committed
up
1 parent 8c65c95 commit d31ede0

File tree

1 file changed

+49
-31
lines changed

1 file changed

+49
-31
lines changed

exir/program/_program.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,14 +1106,31 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
11061106
for name, program in aten_programs.items():
11071107
if partitioner is not None:
11081108
for curr_partitioner in partitioner.get(name, []):
1109-
curr_ops_no_decomp, check_op_support = (
1110-
curr_partitioner.ops_to_not_decompose(program)
1111-
)
1109+
(
1110+
curr_ops_no_decomp,
1111+
check_op_support,
1112+
) = curr_partitioner.ops_to_not_decompose(program)
11121113
if check_op_support is not None:
11131114
can_skip_using_EDGE_DO_NOT_DECOMP = False
11141115
return can_skip_using_EDGE_DO_NOT_DECOMP
11151116

11161117

1118+
def _replace_view_with_view_copy(program: ExportedProgram) -> ExportedProgram:
1119+
program = program.run_decompositions({})
1120+
new_gm = ReplaceViewOpsWithViewCopyOpsPass()(program.graph_module).graph_module
1121+
program = ExportedProgram(
1122+
root=new_gm,
1123+
graph=new_gm.graph,
1124+
graph_signature=_get_updated_graph_signature(program.graph_signature, new_gm),
1125+
state_dict=program.state_dict,
1126+
range_constraints=program.range_constraints,
1127+
module_call_graph=program.module_call_graph,
1128+
example_inputs=program.example_inputs,
1129+
constants=program.constants,
1130+
)
1131+
return program
1132+
1133+
11171134
def _gen_edge_manager_for_partitioners(
11181135
partitioner: Dict[str, List[Partitioner]],
11191136
aten_programs: Dict[str, ExportedProgram],
@@ -1133,58 +1150,55 @@ def _gen_edge_manager_for_partitioners(
11331150
on nodes with preserved aten targets. They are then replaces with transformed ops to
11341151
keep them through the second pass of decompositions
11351152
"""
1153+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1154+
partitioner, aten_programs
1155+
)
11361156
ops_set_to_not_decompose_by_program = {}
11371157
edge_programs: Dict[str, ExportedProgram] = {}
11381158
for name, program in aten_programs.items():
1139-
# Functionalize program without doing any decompositions
1140-
program = program.run_decompositions({})
1141-
ReplaceViewOpsWithViewCopyOpsPass()(program.graph_module)
1142-
1143-
print(program)
1144-
11451159
if partitioner is not None:
11461160
# preserve all ops listed by all partitioners first
11471161
all_ops_no_decomp = set()
1162+
all_ops_no_decomp_needing_preservation = []
11481163
for curr_partitioner in partitioner.get(name, []):
11491164
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1150-
<<<<<<< HEAD
1151-
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1152-
curr_ops_no_decomp
1153-
)
1154-
=======
1155-
>>>>>>> ec44f8478 (updates)
11561165
all_ops_no_decomp |= set(curr_ops_no_decomp)
1157-
1166+
11581167
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1159-
# Otherwise there will be issues
1168+
# Otherwise there will be issues
11601169
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1161-
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(list(all_ops_no_decomp))
1170+
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1171+
list(all_ops_no_decomp)
1172+
)
11621173
all_ops_no_decomp = set(all_ops_no_decomp)
11631174

11641175
# Run default decompositions, except for those in all_ops_no_decomp
11651176
table = _default_decomposition_table()
11661177
for op in all_ops_no_decomp:
1167-
<<<<<<< HEAD
1168-
table.pop(op, None)
1169-
1170-
=======
11711178
if table.pop(op, None) is not None:
11721179
all_ops_no_decomp_needing_preservation.append(op)
1173-
>>>>>>> ec44f8478 (updates)
11741180
program = program.run_decompositions(table)
11751181

11761182
# Among all the preserved aten ops, use the check_op_fn to do an additional
11771183
# check on which ops need to be preserved and which ops need to be decomposed
11781184
# Those which are truly preserved will be replaced with transformed ops
1179-
ops_set_to_not_decompose_by_program[name] = (
1180-
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
1181-
)
1182-
program = program.run_decompositions(_default_decomposition_table())
1185+
if can_skip_using_EDGE_DO_NOT_DECOMP:
1186+
ops_set_to_not_decompose_by_program[
1187+
name
1188+
] = all_ops_no_decomp_needing_preservation
1189+
else:
1190+
ops_set_to_not_decompose_by_program[name] = (
1191+
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
1192+
or []
1193+
)
11831194

1184-
_restore_transformed_ops_to_aten_ops(program)
1195+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1196+
program = program.run_decompositions(_default_decomposition_table())
1197+
_restore_transformed_ops_to_aten_ops(program)
11851198

1199+
# Edge will complain if there are view ops requested for preservation, so we replace them with view_copy
1200+
program = _replace_view_with_view_copy(program)
11861201
edge_programs[name] = program
1187-
11881202
edge_programs[name] = _generate_edge_program(
11891203
name,
11901204
config,
@@ -1229,7 +1243,7 @@ def collect_named_data_store_outputs(
12291243

12301244

12311245
@et_logger("to_edge_transform_and_lower")
1232-
def to_edge_transform_and_lower(
1246+
def to_edge_transform_and_lower( # noqa: C901
12331247
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
12341248
transform_passes: Optional[
12351249
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
@@ -1294,6 +1308,9 @@ def to_edge_transform_and_lower(
12941308
elif partitioner is None:
12951309
partitioner = {name: [] for name in aten_programs.keys()}
12961310

1311+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1312+
partitioner, aten_programs
1313+
)
12971314
edge_manager = _gen_edge_manager_for_partitioners(
12981315
partitioner, aten_programs, config, constant_methods
12991316
)
@@ -1319,7 +1336,8 @@ def to_edge_transform_and_lower(
13191336
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
13201337
program
13211338
)
1322-
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
1339+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1340+
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
13231341
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
13241342
_sanity_check_graph_for_non_decomp_ops(
13251343
name,

0 commit comments

Comments
 (0)