Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
00e3eea170ce5db8ea9c62ce5e48f13886cd6d20
export-D62539799
9 changes: 6 additions & 3 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,9 +924,12 @@ def _gen_edge_manager_for_partitioners(
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
all_ops_no_decomp |= set(curr_ops_no_decomp)

program = program.run_decompositions(
_default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
)
decomp_table = _default_decomposition_table()
for op in all_ops_no_decomp:
if op in decomp_table:
del decomp_table[op]

program = program.run_decompositions(decomp_table)
# Among all the preserved aten ops, use the check_op_fn to do an additional
# check on which ops need to be preserved and which ops need to be decomposed
# Those which are truly preserved will be replaced with transformed ops
Expand Down
10 changes: 6 additions & 4 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,12 @@ def get_num_nondecomposed_ops(self, ep, partitioner):
# which pass the filter_ops fn given by the partitioner
reference_ep = copy.deepcopy(ep)
aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
reference_decomp_ep = reference_ep.run_decompositions(
decomp_table=_default_decomposition_table(),
_preserve_ops=tuple(aten_ops_not_decomposed),
)
decomp_table = _default_decomposition_table()
for op in aten_ops_not_decomposed:
if op in decomp_table:
del decomp_table[op]
reference_decomp_ep = reference_ep.run_decompositions(decomp_table)

num_non_decomposed_aten_ops = 0
for node in reference_decomp_ep.graph.nodes:
if (
Expand Down
Loading