Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/nxp/nxp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def preprocess(
# Otherwise, we get violation that this op is not part of ATen Core ops.
edge_program._verifiers = [
EXIREdgeDialectVerifier(
class_only=True, exception_list=[torch.ops.aten.max_pool2d.default]
class_only=True, core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default]
)
]

Expand Down
5 changes: 5 additions & 0 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ class EdgeCompileConfig:
# TODO(larryliu): remove this
_use_edge_ops: bool = True
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
# Note: only use this for core ATen ops that are missing decompositions. This is temporary,
# enabling verification on the rest of the program until decomposition coverage is improved.
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
default_factory=list
)
# Allow ops to be preserved in the graph, i.e., prevent them from being decomposed.
# These may be core or non-core ATen ops; custom ops should not be here.
_preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
# TODO(gasoonjia): remove this
_skip_dim_order: bool = False

Expand Down
56 changes: 39 additions & 17 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,9 +795,19 @@ def _generate_edge_program(
name: str,
config: EdgeCompileConfig,
program: ExportedProgram,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:

"""
Args:
name: The name of the program.
config: The configuration for the edge program.
program: The exported program to be converted to an edge program.
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
preserve_ops: A list of aten ops that should not be decomposed.
Returns:
An ExportedProgram in edge dialect.
"""
# Remove invalid assert ops, such as _assert_tensor_metadata
gm = program.graph_module
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
Expand All @@ -812,7 +822,8 @@ def _generate_edge_program(
EXIRATenDialectVerifier(
edge_compile_config=config,
class_only=False,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)(gm)
except ExportError as e:
logging.info(f"Input program {name} is not in ATen dialect.")
Expand Down Expand Up @@ -848,7 +859,8 @@ def _generate_edge_program(
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)
],
)
Expand All @@ -864,7 +876,7 @@ def _replace_aten_ops_with_transformed_ops(
program: ExportedProgram,
partitioner,
):
ops_to_not_decompose = set()
preserve_ops = set()
partitioners = partitioner.get(name)
if partitioners is None:
return
Expand All @@ -889,7 +901,7 @@ def _replace_aten_ops_with_transformed_ops(
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
preserve_ops.add(node.target)
node.target = aten_op_to_transform_op[node.target]

for _, submod, _ in get_control_flow_submodules(program.graph_module):
Expand All @@ -900,10 +912,10 @@ def _replace_aten_ops_with_transformed_ops(
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
preserve_ops.add(node.target)
node.target = aten_op_to_transform_op[node.target]

return ops_to_not_decompose
return preserve_ops


def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
Expand Down Expand Up @@ -1014,7 +1026,7 @@ def _sanity_check_graph_for_non_decomp_ops(


def _remove_invalid_ops_for_not_decompose(
ops_to_not_decompose: List[torch._ops.OpOverload],
preserve_ops: List[torch._ops.OpOverload],
) -> List[torch._ops.OpOverload]:
_logged_warnings = set()

Expand Down Expand Up @@ -1079,7 +1091,7 @@ def keep(op):
return False
return True

return list(filter(keep, ops_to_not_decompose))
return list(filter(keep, preserve_ops))


def _gen_edge_manager_for_partitioners(
Expand Down Expand Up @@ -1136,7 +1148,7 @@ def _gen_edge_manager_for_partitioners(
name,
config,
program,
list(ops_set_to_not_decompose_by_program.get(name, [])),
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
)

edge_manager = EdgeProgramManager(
Expand Down Expand Up @@ -1281,7 +1293,7 @@ def to_edge_transform_and_lower(
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=list(ops_set_to_not_decompose),
preserve_ops=list(ops_set_to_not_decompose),
)()(program.graph_module)

return edge_manager
Expand Down Expand Up @@ -1328,7 +1340,7 @@ def to_edge_with_preserved_ops(
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
name, config, program, list(preserve_ops)
name, config, program, preserve_ops=list(preserve_ops)
)

return EdgeProgramManager(
Expand Down Expand Up @@ -1367,8 +1379,16 @@ def to_edge(

for name, program in aten_programs.items():
# Decompose to Core ATen
program = program.run_decompositions(_default_decomposition_table())
edge_programs[name] = _generate_edge_program(name, config, program)
table = _default_decomposition_table()
preserve_ops = []
if compile_config:
preserve_ops = compile_config._preserve_ops
for op in compile_config._preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
name, config, program, preserve_ops=preserve_ops
)

return EdgeProgramManager(edge_programs, constant_methods, config)

Expand All @@ -1389,7 +1409,8 @@ def __init__(
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
):
"""
Should not be called directly by users. User should use :func:'to_edge' instead.
Expand All @@ -1404,7 +1425,8 @@ def __init__(
try:
EXIREdgeDialectVerifier(
edge_compile_config=self.compile_config,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)(program.graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in aten dialect.")
Expand Down
5 changes: 3 additions & 2 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
ExecutorchProgramManager,
to_edge,
to_edge_transform_and_lower,
to_edge_with_preserved_ops,
)
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
Expand Down Expand Up @@ -784,7 +783,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def _test_to_edge_with_preserved_ops(
self, program, preserved_ops, expected_preserved_ops
):
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
edge = to_edge(
program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops)
)

def count_nodes(graph_module, target):
count = 0
Expand Down
14 changes: 14 additions & 0 deletions exir/verification/test/test_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,17 @@ def forward(self, input, label):
edge_verifier = EXIREdgeDialectVerifier()

edge_verifier(edge.exported_program())

def test_verifier_preserve_ops_view(self) -> None:
class TestExpand(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.expand(2, 2, 2, 2)

model = TestExpand()
config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default])
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
with self.assertRaises(RuntimeError):
to_edge(export_model, compile_config=config)
77 changes: 59 additions & 18 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-unsafe

import itertools
import logging
import operator
import types
from contextlib import nullcontext
Expand Down Expand Up @@ -81,26 +84,33 @@ def __call__(self, *args, **kwargs):
def EXIRATenDialectVerifier( # noqa: C901
edge_compile_config: Optional[EdgeCompileConfig] = None,
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
):
"""
Returns a verifier class that runs ATen dialect specific checks on the graph module.
"""
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
_preserve_ops = preserve_ops or []
# merge the exception list from edge_compile_config and exception_list
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
exception_list = edge_compile_config._core_aten_ops_exception_list + (
exception_list or []
)
if edge_compile_config:
if edge_compile_config._core_aten_ops_exception_list:
_core_aten_ops_exception_list.extend(
edge_compile_config._core_aten_ops_exception_list
)
if edge_compile_config._preserve_ops:
_preserve_ops.extend(edge_compile_config._preserve_ops)

class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
dialect = "OLD_EXIR_ATEN"

def __init__(self) -> None:
super().__init__()
# Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
self._exception_list = exception_list if exception_list else []
self._core_aten_ops_exception_list = _core_aten_ops_exception_list
self._preserve_ops = _preserve_ops

def _get_exception_list(self) -> List[torch._ops.OpOverload]:
def _get_core_aten_ops_exception_list(self) -> List[torch._ops.OpOverload]:
exception_list = (
[
torch.ops.aten.mkldnn_rnn_layer.default,
Expand All @@ -113,15 +123,35 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]:
]
+ list(_EXECUTORCH_SYM_OPS)
+ DISALLOW_LIST
+ self._exception_list
+ self._core_aten_ops_exception_list
)

return exception_list

def check_valid_op(self, op):
if isinstance(op, OpOverload):
# TODO These special ops should be removable easily.
if op.namespace != "aten" or op in self._get_exception_list():
if (
op.namespace != "aten"
or op in self._get_core_aten_ops_exception_list()
):
return
if op in self._preserve_ops:
if op.namespace != "aten":
raise RuntimeError(
f"Only preserve aten ops. Received op {op} with namespace {op.namespace}."
)
# Preserved ops should not include mutation or view,
# which may affect memory planning.
if op.is_view:
raise RuntimeError(
f"Cannot preserve operator {op} because it is a view or mutation."
)
if op._schema.is_mutable:
logging.warning(
f"Preserving mutation ops like {op} is a no-op because run_decomposition functionalizes it and prevents it from showing up."
)

return
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
Expand Down Expand Up @@ -149,7 +179,9 @@ def check_valid_op(self, op):
def get_aten_verifier(config: EdgeCompileConfig):
return (
EXIRATenDialectVerifier(
class_only=True, exception_list=config._core_aten_ops_exception_list
class_only=True,
core_aten_ops_exception_list=config._core_aten_ops_exception_list,
preserve_ops=config._preserve_ops,
)
if config._check_ir_validity
else EXIRATenDialectVerifierBase
Expand Down Expand Up @@ -210,13 +242,19 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
def EXIREdgeDialectVerifier( # noqa: C901
edge_compile_config: Optional[EdgeCompileConfig] = None,
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
):
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
_preserve_ops = preserve_ops or []
# merge the exception list from edge_compile_config and exception_list
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
exception_list = edge_compile_config._core_aten_ops_exception_list + (
exception_list or []
)
if edge_compile_config:
if edge_compile_config._core_aten_ops_exception_list:
_core_aten_ops_exception_list.extend(
edge_compile_config._core_aten_ops_exception_list
)
if edge_compile_config._preserve_ops:
_preserve_ops.extend(edge_compile_config._preserve_ops)

class _EXIREdgeDialectVerifier(Verifier):
dialect = "EDGE"
Expand All @@ -228,16 +266,19 @@ def __init__(self) -> None:
self.check_edge_ops = _edge_compile_config._use_edge_ops
self.use_dim_order = not _edge_compile_config._skip_dim_order

self._core_aten_ops_exception_list = _core_aten_ops_exception_list
self._preserve_ops = _preserve_ops

self.aten_op_verifier = EXIRATenDialectVerifier(
exception_list=exception_list
core_aten_ops_exception_list=_core_aten_ops_exception_list,
preserve_ops=_preserve_ops,
)
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

if self.check_edge_ops:
self.check_valid_op = self.check_valid_edge_op
else:
self.check_valid_op = self.check_valid_aten_op
self._exception_list = exception_list if exception_list else []

def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (
Expand All @@ -258,7 +299,7 @@ def check_valid_edge_op(self, op):
in [operator.getitem]
+ DISALLOW_LIST
+ list(_EXECUTORCH_SYM_OPS)
+ self._exception_list
+ self._core_aten_ops_exception_list
):
return

Expand Down
Loading