Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d1b87e26e5c4343f5b56bb1e6f89b479b389bfac
export-D63846832
165 changes: 159 additions & 6 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-unsafe

import copy
import functools
import io
import logging
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
Expand Down Expand Up @@ -56,6 +57,8 @@
get_aten_verifier,
)
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._ops import OperatorBase
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
unsafe_remove_auto_functionalized_pass,
Expand Down Expand Up @@ -925,9 +928,23 @@ def _gen_edge_manager_for_partitioners(
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
all_ops_no_decomp |= set(curr_ops_no_decomp)

program = program.run_decompositions(
_default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
)
decomp_table = _default_decomposition_table()

# FIXME (tmanlaibaatar)
old_export = False
try:
from torch.export import default_decompositions
except ImportError:
old_export = True

if old_export:
for op in _collect_all_valid_cia_ops():
decomp_table[op] = _get_decomp_for_cia(op)

for op in all_ops_no_decomp:
decomp_table.pop(op, None)

program = program.run_decompositions(decomp_table)
# Among all the preserved aten ops, use the check_op_fn to do an additional
# check on which ops need to be preserved and which ops need to be decomposed
# Those which are truly preserved will be replaced with transformed ops
Expand Down Expand Up @@ -1058,6 +1075,130 @@ def to_edge_transform_and_lower(
return edge_manager


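# TODO(tmanlaibaatar): Delete all of the CIA helper functions below. They are
# only used as a fallback when `torch.export.default_decompositions` cannot be
# imported.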
@functools.lru_cache(maxsize=1)
def _materialize_cpp_cia_ops() -> None:
"""
Utility function to query C++ dispatcher to get the all
possible CIA ops and populate them into torch.ops namespace
"""
cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key(
"CompositeImplicitAutograd"
)

# Materialize all CIA ops
for op in cia_ops:
namespace, op_name = tuple(op.split("::"))
split_list = op_name.split(".")
# Sometime overload could be missing
assert len(split_list) == 1 or len(split_list) == 2
op_name = split_list[0]
op_overload_name = "default"
if len(split_list) == 2:
op_overload_name = split_list[1]

_ = getattr(getattr(getattr(torch.ops, namespace), op_name), op_overload_name)


@functools.lru_cache(maxsize=1)
def _collect_all_valid_cia_ops_for_aten_namespace() -> Set["OperatorBase"]:
    """
    Return the preservable CIA op overloads found in the ``aten`` namespace.

    The result is cached, so the namespace is scanned on the first call only.
    """
    aten_cia_ops = _collect_all_valid_cia_ops_for_namespace("aten")
    return aten_cia_ops


def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBase"]:
    """
    Collect every preservable CIA op overload registered under ``namespace``.

    Args:
        namespace: Name of an attribute of ``torch.ops`` (e.g. ``"aten"``).

    Returns:
        The set of overloads in that namespace for which
        ``_is_preservable_cia_op`` returns a truthy value.
    """
    # Ops registered only in the C++ dispatcher must be materialized first,
    # otherwise the Python-side walk below would not see them.
    _materialize_cpp_cia_ops()

    assert hasattr(torch.ops, namespace)
    ops_in_namespace = getattr(torch.ops, namespace)

    preservable = set()
    for packet_name in ops_in_namespace:
        packet = getattr(ops_in_namespace, packet_name)
        preservable.update(
            candidate
            for candidate in (
                getattr(packet, overload_name)
                for overload_name in packet.overloads()
            )
            if _is_preservable_cia_op(candidate)
        )
    return preservable


@functools.lru_cache(maxsize=1)
def _collect_all_valid_cia_ops() -> Set["OperatorBase"]:
    """
    Return the union of preservable CIA op overloads across all namespaces
    listed in ``torch.ops._dir``.

    The ``aten`` namespace goes through its own cached helper; every other
    namespace is scanned directly. The combined result is cached as well.
    """
    collected = set()
    for namespace_name in torch.ops._dir:
        if namespace_name == "aten":
            found = _collect_all_valid_cia_ops_for_aten_namespace()
        else:
            found = _collect_all_valid_cia_ops_for_namespace(namespace_name)
        collected |= found
    return collected


def _is_cia_op(op: "OperatorBase") -> bool:
return (
torch._C._dispatch_has_kernel_for_dispatch_key(
op.name(), torch._C.DispatchKey.CompositeImplicitAutograd
)
or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels
)


def _is_preservable_cia_op(op: "OperatorBase") -> bool:
    """
    Tell whether ``op`` is a CIA op that is also valid to preserve.

    The validity check runs first; the CIA check is only evaluated when the
    op passed it.
    """
    if not _check_valid_to_preserve(op):
        return False
    return _is_cia_op(op)


def _check_valid_to_preserve(op_overload: "OperatorBase"):
if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
return False
if op_overload in FunctionalTensor.metadata_fns:
return False

if not hasattr(op_overload, "_schema"):
return False

alias_info = len(
[i for i in op_overload._schema.arguments if i.alias_info is not None]
)

is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable

if is_mutating_or_aliasing:
return False

if not torch._C._dispatch_has_kernel(op_overload.name()):
return False

return True


def _get_decomp_for_cia(op: "OperatorBase"):
# [NOTE] Seperating out func.decompose
# Ideally we should be able to just register func.decompose but
# we can't as this decomp is gonna be registered to the py_impl.
# As a result it will infinitely recurse. So we first check if the op
# has py_impl entry for CIA and if it is we use that first. If not,
# we register C++ query to py_impl.
dk = torch._C.DispatchKey.CompositeImplicitAutograd
if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey):
return op.py_kernels[dk]

def _special_op_to_decompose_cia(*args, **kwargs):
kernel = kwargs["kernel"]
del kwargs["kernel"]
# Can't call kernel.decompose due to infinite recursion as
# we register this kernel to py_impl directly
dk = torch._C.DispatchKey.CompositeImplicitAutograd
if torch._C._dispatch_has_kernel_for_dispatch_key(
kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
):
return kernel._op_dk(dk, *args, **kwargs)
else:
raise AssertionError(
f"Expected {kernel} to have CompositeImplicitAutograd kernel"
)

return functools.partial(_special_op_to_decompose_cia, kernel=op)


@experimental(
"""
This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
Expand Down Expand Up @@ -1094,9 +1235,21 @@ def to_edge_with_preserved_ops(

for name, program in aten_programs.items():
# Decompose to Core ATen
program = program.run_decompositions(
_default_decomposition_table(), _preserve_ops=preserve_ops
)
decomp_table = _default_decomposition_table()

old_export = False
try:
from torch.export import default_decompositions
except ImportError:
old_export = True

if old_export:
for op in _collect_all_valid_cia_ops():
decomp_table[op] = _get_decomp_for_cia(op)

for op in preserve_ops:
decomp_table.pop(op, None)
program = program.run_decompositions(decomp_table)
edge_programs[name] = _generate_edge_program(
name, config, program, list(preserve_ops)
)
Expand Down
10 changes: 6 additions & 4 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,12 @@ def get_num_nondecomposed_ops(self, ep, partitioner):
# which pass the filter_ops fn given by the partitioner
reference_ep = copy.deepcopy(ep)
aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
reference_decomp_ep = reference_ep.run_decompositions(
decomp_table=_default_decomposition_table(),
_preserve_ops=tuple(aten_ops_not_decomposed),
)

decomp_table = _default_decomposition_table()
for op in aten_ops_not_decomposed:
decomp_table.pop(op, None)

reference_decomp_ep = reference_ep.run_decompositions(decomp_table)
num_non_decomposed_aten_ops = 0
for node in reference_decomp_ep.graph.nodes:
if (
Expand Down
9 changes: 7 additions & 2 deletions exir/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from executorch.exir.types import ValueSpec

from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp import get_decompositions
from torch._dynamo.guards import Guard
from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
from torch.func import functionalize
Expand Down Expand Up @@ -631,7 +631,12 @@ def _default_decomposition_table(
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
return get_decompositions(decomp_opset)
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
return core_aten_decompositions()
try:
from torch.export import default_decompositions
return default_decompositions()
except ImportError:
from torch._decomp import core_aten_decompositions
return core_aten_decompositions()


def dynamo_trace(
Expand Down
Loading