Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ class EdgeCompileConfig:
# TODO(larryliu): remove this
_use_edge_ops: bool = True
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
# Note: only use this for core ATen ops that are missing decompositions. This is temporary,
# enabling verification on the rest of the program until decomposition coverage is improved.
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
default_factory=list
)
# Allow ops to be preserved in the graph, i.e., prevent them from being decomposed.
# These may be core or non-core ATen ops; custom ops should not be here.
_preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
# TODO(gasoonjia): remove this
_skip_dim_order: bool = False

Expand Down
56 changes: 39 additions & 17 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,9 +795,19 @@ def _generate_edge_program(
name: str,
config: EdgeCompileConfig,
program: ExportedProgram,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:

"""
Args:
name: The name of the program.
config: The configuration for the edge program.
program: The exported program to be converted to an edge program.
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
preserve_ops: A list of aten ops that should not be decomposed.
Returns:
An ExportedProgram in edge dialect.
"""
# Remove invalid assert ops, such as _assert_tensor_metadata
gm = program.graph_module
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
Expand All @@ -812,7 +822,8 @@ def _generate_edge_program(
EXIRATenDialectVerifier(
edge_compile_config=config,
class_only=False,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)(gm)
except ExportError as e:
logging.info(f"Input program {name} is not in ATen dialect.")
Expand Down Expand Up @@ -848,7 +859,8 @@ def _generate_edge_program(
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)
],
)
Expand All @@ -864,7 +876,7 @@ def _replace_aten_ops_with_transformed_ops(
program: ExportedProgram,
partitioner,
):
ops_to_not_decompose = set()
preserve_ops = set()
partitioners = partitioner.get(name)
if partitioners is None:
return
Expand All @@ -889,7 +901,7 @@ def _replace_aten_ops_with_transformed_ops(
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
preserve_ops.add(node.target)
node.target = aten_op_to_transform_op[node.target]

for _, submod, _ in get_control_flow_submodules(program.graph_module):
Expand All @@ -900,10 +912,10 @@ def _replace_aten_ops_with_transformed_ops(
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
preserve_ops.add(node.target)
node.target = aten_op_to_transform_op[node.target]

return ops_to_not_decompose
return preserve_ops


def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
Expand Down Expand Up @@ -1014,7 +1026,7 @@ def _sanity_check_graph_for_non_decomp_ops(


def _remove_invalid_ops_for_not_decompose(
ops_to_not_decompose: List[torch._ops.OpOverload],
preserve_ops: List[torch._ops.OpOverload],
) -> List[torch._ops.OpOverload]:
_logged_warnings = set()

Expand Down Expand Up @@ -1079,7 +1091,7 @@ def keep(op):
return False
return True

return list(filter(keep, ops_to_not_decompose))
return list(filter(keep, preserve_ops))


def _gen_edge_manager_for_partitioners(
Expand Down Expand Up @@ -1136,7 +1148,7 @@ def _gen_edge_manager_for_partitioners(
name,
config,
program,
list(ops_set_to_not_decompose_by_program.get(name, [])),
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
)

edge_manager = EdgeProgramManager(
Expand Down Expand Up @@ -1281,7 +1293,7 @@ def to_edge_transform_and_lower(
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=list(ops_set_to_not_decompose),
preserve_ops=list(ops_set_to_not_decompose),
)()(program.graph_module)

return edge_manager
Expand Down Expand Up @@ -1328,7 +1340,7 @@ def to_edge_with_preserved_ops(
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
name, config, program, list(preserve_ops)
name, config, program, preserve_ops=list(preserve_ops)
)

return EdgeProgramManager(
Expand Down Expand Up @@ -1367,8 +1379,16 @@ def to_edge(

for name, program in aten_programs.items():
# Decompose to Core ATen
program = program.run_decompositions(_default_decomposition_table())
edge_programs[name] = _generate_edge_program(name, config, program)
table = _default_decomposition_table()
preserve_ops = []
if compile_config:
preserve_ops = compile_config._preserve_ops
for op in compile_config._preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
name, config, program, preserve_ops=preserve_ops
)

return EdgeProgramManager(edge_programs, constant_methods, config)

Expand All @@ -1389,7 +1409,8 @@ def __init__(
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
):
"""
Should not be called directly by users. User should use :func:'to_edge' instead.
Expand All @@ -1404,7 +1425,8 @@ def __init__(
try:
EXIREdgeDialectVerifier(
edge_compile_config=self.compile_config,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)(program.graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in aten dialect.")
Expand Down
5 changes: 3 additions & 2 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
ExecutorchProgramManager,
to_edge,
to_edge_transform_and_lower,
to_edge_with_preserved_ops,
)
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
Expand Down Expand Up @@ -784,7 +783,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def _test_to_edge_with_preserved_ops(
self, program, preserved_ops, expected_preserved_ops
):
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
edge = to_edge(
program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops)
)

def count_nodes(graph_module, target):
count = 0
Expand Down
14 changes: 14 additions & 0 deletions exir/verification/test/test_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,17 @@ def forward(self, input, label):
edge_verifier = EXIREdgeDialectVerifier()

edge_verifier(edge.exported_program())

def test_verifier_preserve_ops_view(self) -> None:
class TestExpand(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.expand(2, 2, 2, 2)

model = TestExpand()
config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default])
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
with self.assertRaises(RuntimeError):
to_edge(export_model, compile_config=config)
75 changes: 57 additions & 18 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import operator
import types
from contextlib import nullcontext
Expand Down Expand Up @@ -81,26 +82,33 @@ def __call__(self, *args, **kwargs):
def EXIRATenDialectVerifier( # noqa: C901
edge_compile_config: Optional[EdgeCompileConfig] = None,
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
):
"""
Returns a verifier class that runs ATen dialect specific checks on the graph module.
"""
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
_preserve_ops = preserve_ops or []
# merge the exception list from edge_compile_config and exception_list
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
exception_list = edge_compile_config._core_aten_ops_exception_list + (
exception_list or []
)
if edge_compile_config:
if edge_compile_config._core_aten_ops_exception_list:
_core_aten_ops_exception_list.extend(
edge_compile_config._core_aten_ops_exception_list
)
if edge_compile_config._preserve_ops:
_preserve_ops.extend(edge_compile_config._preserve_ops)

class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
dialect = "OLD_EXIR_ATEN"

def __init__(self) -> None:
super().__init__()
# Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
self._exception_list = exception_list if exception_list else []
self._core_aten_ops_exception_list = _core_aten_ops_exception_list
self._preserve_ops = _preserve_ops

def _get_exception_list(self) -> List[torch._ops.OpOverload]:
def _get_core_aten_ops_exception_list(self) -> List[torch._ops.OpOverload]:
exception_list = (
[
torch.ops.aten.mkldnn_rnn_layer.default,
Expand All @@ -113,15 +121,35 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]:
]
+ list(_EXECUTORCH_SYM_OPS)
+ DISALLOW_LIST
+ self._exception_list
+ self._core_aten_ops_exception_list
)

return exception_list

def check_valid_op(self, op):
if isinstance(op, OpOverload):
# TODO These special ops should be removable easily.
if op.namespace != "aten" or op in self._get_exception_list():
if (
op.namespace != "aten"
or op in self._get_core_aten_ops_exception_list()
):
return
if op in self._preserve_ops:
if op.namespace != "aten":
raise RuntimeError(
f"Only preserve aten ops. Received op {op} with namespace {op.namespace}."
)
# Preserved ops should not include mutation or view,
# which may affect memory planning.
if op.is_view:
raise RuntimeError(
f"Cannot preserve operator {op} because it is a view or mutation."
)
if op._schema.is_mutable:
logging.warning(
f"Preserving mutation ops like {op} is a no-op because run_decomposition functionalizes it and prevents it from showing up."
)

return
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
Expand Down Expand Up @@ -149,7 +177,9 @@ def check_valid_op(self, op):
def get_aten_verifier(config: EdgeCompileConfig):
return (
EXIRATenDialectVerifier(
class_only=True, exception_list=config._core_aten_ops_exception_list
class_only=True,
core_aten_ops_exception_list=config._core_aten_ops_exception_list,
preserve_ops=config._preserve_ops,
)
if config._check_ir_validity
else EXIRATenDialectVerifierBase
Expand Down Expand Up @@ -210,13 +240,19 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
def EXIREdgeDialectVerifier( # noqa: C901
edge_compile_config: Optional[EdgeCompileConfig] = None,
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
):
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
_preserve_ops = preserve_ops or []
# merge the exception list from edge_compile_config and exception_list
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
exception_list = edge_compile_config._core_aten_ops_exception_list + (
exception_list or []
)
if edge_compile_config:
if edge_compile_config._core_aten_ops_exception_list:
_core_aten_ops_exception_list.extend(
edge_compile_config._core_aten_ops_exception_list
)
if edge_compile_config._preserve_ops:
_preserve_ops.extend(edge_compile_config._preserve_ops)

class _EXIREdgeDialectVerifier(Verifier):
dialect = "EDGE"
Expand All @@ -228,16 +264,19 @@ def __init__(self) -> None:
self.check_edge_ops = _edge_compile_config._use_edge_ops
self.use_dim_order = not _edge_compile_config._skip_dim_order

self._core_aten_ops_exception_list = _core_aten_ops_exception_list
self._preserve_ops = _preserve_ops

self.aten_op_verifier = EXIRATenDialectVerifier(
exception_list=exception_list
core_aten_ops_exception_list=_core_aten_ops_exception_list,
preserve_ops=_preserve_ops,
)
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

if self.check_edge_ops:
self.check_valid_op = self.check_valid_edge_op
else:
self.check_valid_op = self.check_valid_aten_op
self._exception_list = exception_list if exception_list else []

def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (
Expand All @@ -258,7 +297,7 @@ def check_valid_edge_op(self, op):
in [operator.getitem]
+ DISALLOW_LIST
+ list(_EXECUTORCH_SYM_OPS)
+ self._exception_list
+ self._core_aten_ops_exception_list
):
return

Expand Down
Loading