diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index 21a0ea5d478..9aaea8851d4 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -d1b87e26e5c4343f5b56bb1e6f89b479b389bfac +export-D64151426 diff --git a/exir/program/_program.py b/exir/program/_program.py index fa1d84db7c6..f7d7026bdc2 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -925,9 +925,12 @@ def _gen_edge_manager_for_partitioners( curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) all_ops_no_decomp |= set(curr_ops_no_decomp) - program = program.run_decompositions( - _default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp) - ) + table = _default_decomposition_table() + + for op in all_ops_no_decomp: + table.pop(op, None) + + program = program.run_decompositions(table) # Among all the preserved aten ops, use the check_op_fn to do an additional # check on which ops need to be preserved and which ops need to be decomposed # Those which are truly preserved will be replaced with transformed ops @@ -1097,9 +1100,10 @@ def to_edge_with_preserved_ops( for name, program in aten_programs.items(): # Decompose to Core ATen - program = program.run_decompositions( - _default_decomposition_table(), _preserve_ops=preserve_ops - ) + table = _default_decomposition_table() + for op in preserve_ops: + table.pop(op, None) + program = program.run_decompositions(table) edge_programs[name] = _generate_edge_program( name, config, program, list(preserve_ops) ) diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 73eea7b93ef..98199ac0dc1 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -573,10 +573,10 @@ def get_num_nondecomposed_ops(self, ep, partitioner): # which pass the filter_ops fn given by the partitioner reference_ep = copy.deepcopy(ep) aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep) - reference_decomp_ep = reference_ep.run_decompositions( - decomp_table=_default_decomposition_table(), - _preserve_ops=tuple(aten_ops_not_decomposed), - ) + table = _default_decomposition_table() + for op in aten_ops_not_decomposed: + table.pop(op, None) + reference_decomp_ep = reference_ep.run_decompositions(decomp_table=table) num_non_decomposed_aten_ops = 0 for node in reference_decomp_ep.graph.nodes: if ( diff --git a/exir/tracer.py b/exir/tracer.py index c4593cca8e3..09bb4780780 100644 --- a/exir/tracer.py +++ b/exir/tracer.py @@ -44,9 +44,10 @@ from executorch.exir.types import ValueSpec from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual -from torch._decomp import core_aten_decompositions, get_decompositions +from torch._decomp import get_decompositions from torch._dynamo.guards import Guard from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor +from torch.export import default_decompositions from torch.func import functionalize from torch.fx.operator_schemas import normalize_function from torch.utils._pytree import TreeSpec @@ -631,7 +632,7 @@ def _default_decomposition_table( # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e... return get_decompositions(decomp_opset) # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir.... - return core_aten_decompositions() + return default_decompositions() def dynamo_trace( diff --git a/install_requirements.py b/install_requirements.py index 90acb9467ce..dbdbecdaa3f 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -106,7 +106,7 @@ def python_is_compatible(): # NOTE: If a newly-fetched version of the executorch repo changes the value of # NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -NIGHTLY_VERSION = "dev20241007" +NIGHTLY_VERSION = "dev20241019" # The pip repository that hosts nightly torch packages. TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu"