diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index b291722c3f0..1369f41b48d 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -00e3eea170ce5db8ea9c62ce5e48f13886cd6d20 +export-D62539799 diff --git a/exir/program/_program.py b/exir/program/_program.py index 6b72d190f9d..2a39d909de5 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -924,9 +924,12 @@ def _gen_edge_manager_for_partitioners( curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) all_ops_no_decomp |= set(curr_ops_no_decomp) - program = program.run_decompositions( - _default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp) - ) + decomp_table = _default_decomposition_table() + for op in all_ops_no_decomp: + if op in decomp_table: + del decomp_table[op] + + program = program.run_decompositions(decomp_table) # Among all the preserved aten ops, use the check_op_fn to do an additional # check on which ops need to be preserved and which ops need to be decomposed # Those which are truly preserved will be replaced with transformed ops diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 73f023e778b..d3ece089440 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -574,10 +574,12 @@ def get_num_nondecomposed_ops(self, ep, partitioner): # which pass the filter_ops fn given by the partitioner reference_ep = copy.deepcopy(ep) aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep) - reference_decomp_ep = reference_ep.run_decompositions( - decomp_table=_default_decomposition_table(), - _preserve_ops=tuple(aten_ops_not_decomposed), - ) + decomp_table = _default_decomposition_table() + for op in aten_ops_not_decomposed: + if op in decomp_table: + del decomp_table[op] + reference_decomp_ep = reference_ep.run_decompositions(decomp_table) + num_non_decomposed_aten_ops = 0 for node in reference_decomp_ep.graph.nodes: if (