From 7901e1fe864bc34efcd11105d61b0789467136c8 Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Tue, 8 Oct 2024 22:53:41 -0700
Subject: [PATCH] Extend preservable op list to custom ops (#5850)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/137289

Users can now preserve custom ops as well. We do so by extending the
current logic to also query CIA ops from all possible op namespaces.

Test Plan: Imported from OSS

Differential Revision: D63846832

Pulled By: tugsbayasgalan
---
 .ci/docker/ci_commit_pins/pytorch.txt |   2 +-
 exir/program/_program.py              | 165 +++++++++++++++++++++++++-
 exir/program/test/test_program.py     |  10 +-
 exir/tracer.py                        |   9 +-
 4 files changed, 173 insertions(+), 13 deletions(-)

diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt
index 21a0ea5d478..fc690a42352 100644
--- a/.ci/docker/ci_commit_pins/pytorch.txt
+++ b/.ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
-d1b87e26e5c4343f5b56bb1e6f89b479b389bfac
+export-D63846832
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 144cd0d0e8e..9d0184f66cc 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -7,6 +7,7 @@
 # pyre-unsafe
 
 import copy
+import functools
 import io
 import logging
 from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
@@ -56,6 +57,8 @@
     get_aten_verifier,
 )
 from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
+from torch._ops import OperatorBase
+from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.export import ExportedProgram
 from torch.export._remove_auto_functionalized_pass import (
     unsafe_remove_auto_functionalized_pass,
@@ -925,9 +928,23 @@ def _gen_edge_manager_for_partitioners(
         curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
         all_ops_no_decomp |= set(curr_ops_no_decomp)
 
-    program = program.run_decompositions(
-        _default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
-    )
+    decomp_table = _default_decomposition_table()
+
+    # FIXME (tmanlaibaatar)
+    old_export = False
+    try:
+        from torch.export import default_decompositions
+    except ImportError:
+        old_export = True
+
+    if old_export:
+        for op in _collect_all_valid_cia_ops():
+            decomp_table[op] = _get_decomp_for_cia(op)
+
+    for op in all_ops_no_decomp:
+        decomp_table.pop(op, None)
+
+    program = program.run_decompositions(decomp_table)
     # Among all the preserved aten ops, use the check_op_fn to do an additional
     # check on which ops need to be preserved and which ops need to be decomposed
     # Those which are truly preserved will be replaced with transformed ops
@@ -1058,6 +1075,130 @@ def to_edge_transform_and_lower(
     return edge_manager
 
 
+# (tmanlaibaatar) DELETE ALL THIS
+@functools.lru_cache(maxsize=1)
+def _materialize_cpp_cia_ops() -> None:
+    """
+    Utility function to query the C++ dispatcher for all possible
+    CIA ops and populate them into the torch.ops namespace.
+    """
+    cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key(
+        "CompositeImplicitAutograd"
+    )
+
+    # Materialize all CIA ops
+    for op in cia_ops:
+        namespace, op_name = tuple(op.split("::"))
+        split_list = op_name.split(".")
+        # Sometimes the overload can be missing
+        assert len(split_list) == 1 or len(split_list) == 2
+        op_name = split_list[0]
+        op_overload_name = "default"
+        if len(split_list) == 2:
+            op_overload_name = split_list[1]
+
+        _ = getattr(getattr(getattr(torch.ops, namespace), op_name), op_overload_name)
+
+
+@functools.lru_cache(maxsize=1)
+def _collect_all_valid_cia_ops_for_aten_namespace() -> Set["OperatorBase"]:
+    return _collect_all_valid_cia_ops_for_namespace("aten")
+
+
+def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBase"]:
+    # Step 1: Materialize all ops from C++ dispatcher
+    _materialize_cpp_cia_ops()
+
+    # Step 2: Query all ops from python dispatcher
+    assert hasattr(torch.ops, namespace)
+    op_namespace = getattr(torch.ops, namespace)
+    cia_ops = set()
+    for op in op_namespace:
+        op_packet = getattr(op_namespace, op)
+        for overload in op_packet.overloads():
+            op_overload = getattr(op_packet, overload)
+            if _is_preservable_cia_op(op_overload):
+                cia_ops.add(op_overload)
+    return cia_ops
+
+
+@functools.lru_cache(maxsize=1)
+def _collect_all_valid_cia_ops() -> Set["OperatorBase"]:
+    cia_ops = set()
+    for op_namespace_name in torch.ops._dir:
+        if op_namespace_name != "aten":
+            cia_ops |= _collect_all_valid_cia_ops_for_namespace(op_namespace_name)
+        else:
+            cia_ops |= _collect_all_valid_cia_ops_for_aten_namespace()
+    return cia_ops
+
+
+def _is_cia_op(op: "OperatorBase") -> bool:
+    return (
+        torch._C._dispatch_has_kernel_for_dispatch_key(
+            op.name(), torch._C.DispatchKey.CompositeImplicitAutograd
+        )
+        or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels
+    )
+
+
+def _is_preservable_cia_op(op: "OperatorBase") -> bool:
+    return _check_valid_to_preserve(op) and _is_cia_op(op)
+
+
+def _check_valid_to_preserve(op_overload: "OperatorBase") -> bool:
+    if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
+        return False
+    if op_overload in FunctionalTensor.metadata_fns:
+        return False
+
+    if not hasattr(op_overload, "_schema"):
+        return False
+
+    num_aliased_args = len(
+        [i for i in op_overload._schema.arguments if i.alias_info is not None]
+    )
+
+    is_mutating_or_aliasing = num_aliased_args != 0 or op_overload._schema.is_mutable
+
+    if is_mutating_or_aliasing:
+        return False
+
+    if not torch._C._dispatch_has_kernel(op_overload.name()):
+        return False
+
+    return True
+
+
+def _get_decomp_for_cia(op: "OperatorBase"):
+    # [NOTE] Separating out func.decompose
+    # Ideally we would just register func.decompose, but we can't,
+    # because this decomp is registered via py_impl and calling it
+    # from there would recurse infinitely. So we first check whether
+    # the op has a py_impl entry for CIA and, if so, use that. If not,
+    # we register the C++ kernel query via py_impl.
+    dk = torch._C.DispatchKey.CompositeImplicitAutograd
+    if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey):
+        return op.py_kernels[dk]
+
+    def _special_op_to_decompose_cia(*args, **kwargs):
+        kernel = kwargs["kernel"]
+        del kwargs["kernel"]
+        # Can't call kernel.decompose here: we register this kernel
+        # via py_impl directly, which would recurse infinitely
+        dk = torch._C.DispatchKey.CompositeImplicitAutograd
+        if torch._C._dispatch_has_kernel_for_dispatch_key(
+            kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
+        ):
+            return kernel._op_dk(dk, *args, **kwargs)
+        else:
+            raise AssertionError(
+                f"Expected {kernel} to have CompositeImplicitAutograd kernel"
+            )
+
+    return functools.partial(_special_op_to_decompose_cia, kernel=op)
+
+
 @experimental(
     """
 This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
@@ -1094,9 +1235,21 @@ def to_edge_with_preserved_ops(
 
     for name, program in aten_programs.items():
         # Decompose to Core ATen
-        program = program.run_decompositions(
-            _default_decomposition_table(), _preserve_ops=preserve_ops
-        )
+        decomp_table = _default_decomposition_table()
+
+        old_export = False
+        try:
+            from torch.export import default_decompositions
+        except ImportError:
+            old_export = True
+
+        if old_export:
+            for op in _collect_all_valid_cia_ops():
+                decomp_table[op] = _get_decomp_for_cia(op)
+
+        for op in preserve_ops:
+            decomp_table.pop(op, None)
+        program = program.run_decompositions(decomp_table)
         edge_programs[name] = _generate_edge_program(
             name, config, program, list(preserve_ops)
         )
diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py
index 73eea7b93ef..bf9951e93eb 100644
--- a/exir/program/test/test_program.py
+++ b/exir/program/test/test_program.py
@@ -573,10 +573,12 @@ def get_num_nondecomposed_ops(self, ep, partitioner):
         # which pass the filter_ops fn given by the partitioner
         reference_ep = copy.deepcopy(ep)
         aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
-        reference_decomp_ep = reference_ep.run_decompositions(
-            decomp_table=_default_decomposition_table(),
-            _preserve_ops=tuple(aten_ops_not_decomposed),
-        )
+
+        decomp_table = _default_decomposition_table()
+        for op in aten_ops_not_decomposed:
+            decomp_table.pop(op, None)
+
+        reference_decomp_ep = reference_ep.run_decompositions(decomp_table)
         num_non_decomposed_aten_ops = 0
         for node in reference_decomp_ep.graph.nodes:
             if (
diff --git a/exir/tracer.py b/exir/tracer.py
index c4593cca8e3..9205fcd85f9 100644
--- a/exir/tracer.py
+++ b/exir/tracer.py
@@ -44,7 +44,7 @@
 from executorch.exir.types import ValueSpec
 
 from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass  # @manual
-from torch._decomp import core_aten_decompositions, get_decompositions
+from torch._decomp import get_decompositions
 from torch._dynamo.guards import Guard
 from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
 from torch.func import functionalize
@@ -631,7 +631,12 @@ def _default_decomposition_table(
         # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
         return get_decompositions(decomp_opset)
     # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
-    return core_aten_decompositions()
+    try:
+        from torch.export import default_decompositions
+        return default_decompositions()
+    except ImportError:
+        from torch._decomp import core_aten_decompositions
+        return core_aten_decompositions()
 
 
 def dynamo_trace(
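
Usage sketch (illustrative, not part of the patch): the snippet below shows
what this change enables, assuming the patch is applied. The "mylib"
namespace and the "double_it" op are hypothetical names chosen for the
example.

    import torch
    from torch.library import Library

    from executorch.exir.program._program import _is_preservable_cia_op

    # Define a custom op in a non-aten namespace and register a
    # CompositeImplicitAutograd (CIA) kernel for it, i.e. a kernel that
    # decomposes into existing ATen ops.
    lib = Library("mylib", "DEF")
    lib.define("double_it(Tensor x) -> Tensor")

    def double_it_cia(x):
        return x * 2

    lib.impl("double_it", double_it_cia, "CompositeImplicitAutograd")

    # The op is functional (its schema has no aliasing or mutation), so it
    # now counts as a preservable CIA op. Before this patch only aten ops
    # were considered; _collect_all_valid_cia_ops() now walks custom
    # namespaces such as "mylib" as well.
    assert _is_preservable_cia_op(torch.ops.mylib.double_it.default)

With this in place, a partitioner that lists torch.ops.mylib.double_it.default
in ops_to_not_decompose() gets the op popped from the decomposition table, so
it survives to_edge_transform_and_lower() without being decomposed.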