diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index 3233cf6dbd9..6c7a7c77e83 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -174,7 +174,7 @@ def preprocess( # Otherwise, we get violation that this op is not part of ATen Core ops. edge_program._verifiers = [ EXIREdgeDialectVerifier( - class_only=True, exception_list=[torch.ops.aten.max_pool2d.default] + class_only=True, core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default] ) ] diff --git a/exir/capture/_config.py b/exir/capture/_config.py index d66bc24976d..835bc60dad3 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -40,9 +40,14 @@ class EdgeCompileConfig: # TODO(larryliu): remove this _use_edge_ops: bool = True # Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks. + # Note: only use this for core ATen ops that are missing decompositions. This is temporary, + # enabling verification on the rest of the program until decomposition coverage is improved. _core_aten_ops_exception_list: List[torch._ops.OpOverload] = field( default_factory=list ) + # Allow ops to be preserved in the graph, i.e., prevent them from being decomposed. + # These may be core or non-core ATen ops; custom ops should not be here. + _preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list) # TODO(gasoonjia): remove this _skip_dim_order: bool = False diff --git a/exir/program/_program.py b/exir/program/_program.py index 8ef02f233ac..3e06d01788a 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -795,9 +795,19 @@ def _generate_edge_program( name: str, config: EdgeCompileConfig, program: ExportedProgram, - ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, + core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None, + preserve_ops: Optional[List[torch._ops.OpOverload]] = None, ) -> ExportedProgram: - + """ + Args: + name: The name of the program. + config: The configuration for the edge program. + program: The exported program to be converted to an edge program. + core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten. + preserve_ops: A list of aten ops that should not be decomposed. + Returns: + An ExportedProgram in edge dialect. + """ # Remove invalid assert ops, such as _assert_tensor_metadata gm = program.graph_module gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm) @@ -812,7 +822,8 @@ def _generate_edge_program( EXIRATenDialectVerifier( edge_compile_config=config, class_only=False, - exception_list=ops_set_to_not_decompose, + core_aten_ops_exception_list=core_aten_ops_exception_list, + preserve_ops=preserve_ops, )(gm) except ExportError as e: logging.info(f"Input program {name} is not in ATen dialect.") @@ -848,7 +859,8 @@ def _generate_edge_program( EXIREdgeDialectVerifier( edge_compile_config=config, class_only=True, - exception_list=ops_set_to_not_decompose, + core_aten_ops_exception_list=core_aten_ops_exception_list, + preserve_ops=preserve_ops, ) ], ) @@ -864,7 +876,7 @@ def _replace_aten_ops_with_transformed_ops( program: ExportedProgram, partitioner, ): - ops_to_not_decompose = set() + preserve_ops = set() partitioners = partitioner.get(name) if partitioners is None: return @@ -889,7 +901,7 @@ def _replace_aten_ops_with_transformed_ops( and node.target in ops_set_to_not_decompose and is_op_supported ): - ops_to_not_decompose.add(node.target) + preserve_ops.add(node.target) node.target = aten_op_to_transform_op[node.target] for _, submod, _ in get_control_flow_submodules(program.graph_module): @@ -900,10 +912,10 @@ def _replace_aten_ops_with_transformed_ops( and node.target in ops_set_to_not_decompose and is_op_supported ): - ops_to_not_decompose.add(node.target) + preserve_ops.add(node.target) node.target = aten_op_to_transform_op[node.target] - return ops_to_not_decompose + return preserve_ops def _restore_transformed_ops_to_aten_ops(program: ExportedProgram): @@ -1014,7 +1026,7 @@ def _sanity_check_graph_for_non_decomp_ops( def _remove_invalid_ops_for_not_decompose( - ops_to_not_decompose: List[torch._ops.OpOverload], + preserve_ops: List[torch._ops.OpOverload], ) -> List[torch._ops.OpOverload]: _logged_warnings = set() @@ -1079,7 +1091,7 @@ def keep(op): return False return True - return list(filter(keep, ops_to_not_decompose)) + return list(filter(keep, preserve_ops)) def _gen_edge_manager_for_partitioners( @@ -1136,7 +1148,7 @@ def _gen_edge_manager_for_partitioners( name, config, program, - list(ops_set_to_not_decompose_by_program.get(name, [])), + preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])), ) edge_manager = EdgeProgramManager( @@ -1281,7 +1293,7 @@ def to_edge_transform_and_lower( EXIREdgeDialectVerifier( edge_compile_config=config, class_only=True, - exception_list=list(ops_set_to_not_decompose), + preserve_ops=list(ops_set_to_not_decompose), )()(program.graph_module) return edge_manager @@ -1328,7 +1340,7 @@ def to_edge_with_preserved_ops( table.pop(op, None) program = program.run_decompositions(table) edge_programs[name] = _generate_edge_program( - name, config, program, list(preserve_ops) + name, config, program, preserve_ops=list(preserve_ops) ) return EdgeProgramManager( @@ -1367,8 +1379,16 @@ def to_edge( for name, program in aten_programs.items(): # Decompose to Core ATen - program = program.run_decompositions(_default_decomposition_table()) - edge_programs[name] = _generate_edge_program(name, config, program) + table = _default_decomposition_table() + preserve_ops = [] + if compile_config: + preserve_ops = compile_config._preserve_ops + for op in compile_config._preserve_ops: + table.pop(op, None) + program = program.run_decompositions(table) + edge_programs[name] = _generate_edge_program( + name, config, program, preserve_ops=preserve_ops + ) return EdgeProgramManager(edge_programs, constant_methods, config) @@ -1389,7 +1409,8 @@ def __init__( edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, - ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, + core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None, + preserve_ops: Optional[List[torch._ops.OpOverload]] = None, ): """ Should not be called directly by users. User should use :func:'to_edge' instead. @@ -1404,7 +1425,8 @@ def __init__( try: EXIREdgeDialectVerifier( edge_compile_config=self.compile_config, - exception_list=ops_set_to_not_decompose, + core_aten_ops_exception_list=core_aten_ops_exception_list, + preserve_ops=preserve_ops, )(program.graph_module) except ExportError as e: logging.info(f"Input program {name} is not in aten dialect.") diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index d5de78909ce..7173b4d50b5 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -27,7 +27,6 @@ ExecutorchProgramManager, to_edge, to_edge_transform_and_lower, - to_edge_with_preserved_ops, ) from executorch.exir.tracer import _default_decomposition_table from executorch.exir.verification.verifier import EXIREdgeDialectVerifier @@ -784,7 +783,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _test_to_edge_with_preserved_ops( self, program, preserved_ops, expected_preserved_ops ): - edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops) + edge = to_edge( + program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops) + ) def count_nodes(graph_module, target): count = 0 diff --git a/exir/verification/test/test_verifier.py b/exir/verification/test/test_verifier.py index 8520d3ce13e..2be4aeac3ab 100644 --- a/exir/verification/test/test_verifier.py +++ b/exir/verification/test/test_verifier.py @@ -161,3 +161,17 @@ def forward(self, input, label): edge_verifier = EXIREdgeDialectVerifier() edge_verifier(edge.exported_program()) + + def test_verifier_preserve_ops_view(self) -> None: + class TestExpand(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.expand(2, 2, 2, 2) + + model = TestExpand() + config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default]) + export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True) + with self.assertRaises(RuntimeError): + to_edge(export_model, compile_config=config) diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index bc510ff6849..ed304e99fc1 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -3,8 +3,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# +# pyre-unsafe import itertools +import logging import operator import types from contextlib import nullcontext @@ -81,16 +84,22 @@ def __call__(self, *args, **kwargs): def EXIRATenDialectVerifier( # noqa: C901 edge_compile_config: Optional[EdgeCompileConfig] = None, class_only: bool = False, - exception_list: Optional[List[torch._ops.OpOverload]] = None, + core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None, + preserve_ops: Optional[List[torch._ops.OpOverload]] = None, ): """ Returns a verifier class that runs ATen dialect specific checks on the graph module. """ + _core_aten_ops_exception_list = core_aten_ops_exception_list or [] + _preserve_ops = preserve_ops or [] # merge the exception list from edge_compile_config and exception_list - if edge_compile_config and edge_compile_config._core_aten_ops_exception_list: - exception_list = edge_compile_config._core_aten_ops_exception_list + ( - exception_list or [] - ) + if edge_compile_config: + if edge_compile_config._core_aten_ops_exception_list: + _core_aten_ops_exception_list.extend( + edge_compile_config._core_aten_ops_exception_list + ) + if edge_compile_config._preserve_ops: + _preserve_ops.extend(edge_compile_config._preserve_ops) class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase): dialect = "OLD_EXIR_ATEN" @@ -98,9 +107,10 @@ class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase): def __init__(self) -> None: super().__init__() # Note: here we are using the exception list passed from EXIRATenDialectVerifier function! - self._exception_list = exception_list if exception_list else [] + self._core_aten_ops_exception_list = _core_aten_ops_exception_list + self._preserve_ops = _preserve_ops - def _get_exception_list(self) -> List[torch._ops.OpOverload]: + def _get_core_aten_ops_exception_list(self) -> List[torch._ops.OpOverload]: exception_list = ( [ torch.ops.aten.mkldnn_rnn_layer.default, @@ -113,7 +123,7 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]: ] + list(_EXECUTORCH_SYM_OPS) + DISALLOW_LIST - + self._exception_list + + self._core_aten_ops_exception_list ) return exception_list @@ -121,7 +131,27 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]: def check_valid_op(self, op): if isinstance(op, OpOverload): # TODO These special ops should be removable easily. - if op.namespace != "aten" or op in self._get_exception_list(): + if ( + op.namespace != "aten" + or op in self._get_core_aten_ops_exception_list() + ): + return + if op in self._preserve_ops: + if op.namespace != "aten": + raise RuntimeError( + f"Only preserve aten ops. Received op {op} with namespace {op.namespace}." + ) + # Preserved ops should not include mutation or view, + # which may affect memory planning. + if op.is_view: + raise RuntimeError( + f"Cannot preserve operator {op} because it is a view or mutation." + ) + if op._schema.is_mutable: + logging.warning( + f"Preserving mutation ops like {op} is a no-op because run_decomposition functionalizes it and prevents it from showing up." + ) + return if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags: # NOTE(qihan): whether view_copy operators are marked as canonical is still under @@ -149,7 +179,9 @@ def check_valid_op(self, op): def get_aten_verifier(config: EdgeCompileConfig): return ( EXIRATenDialectVerifier( - class_only=True, exception_list=config._core_aten_ops_exception_list + class_only=True, + core_aten_ops_exception_list=config._core_aten_ops_exception_list, + preserve_ops=config._preserve_ops, ) if config._check_ir_validity else EXIRATenDialectVerifierBase @@ -210,13 +242,19 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None: def EXIREdgeDialectVerifier( # noqa: C901 edge_compile_config: Optional[EdgeCompileConfig] = None, class_only: bool = False, - exception_list: Optional[List[torch._ops.OpOverload]] = None, + core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None, + preserve_ops: Optional[List[torch._ops.OpOverload]] = None, ): + _core_aten_ops_exception_list = core_aten_ops_exception_list or [] + _preserve_ops = preserve_ops or [] # merge the exception list from edge_compile_config and exception_list - if edge_compile_config and edge_compile_config._core_aten_ops_exception_list: - exception_list = edge_compile_config._core_aten_ops_exception_list + ( - exception_list or [] - ) + if edge_compile_config: + if edge_compile_config._core_aten_ops_exception_list: + _core_aten_ops_exception_list.extend( + edge_compile_config._core_aten_ops_exception_list + ) + if edge_compile_config._preserve_ops: + _preserve_ops.extend(edge_compile_config._preserve_ops) class _EXIREdgeDialectVerifier(Verifier): dialect = "EDGE" @@ -228,8 +266,12 @@ def __init__(self) -> None: self.check_edge_ops = _edge_compile_config._use_edge_ops self.use_dim_order = not _edge_compile_config._skip_dim_order + self._core_aten_ops_exception_list = _core_aten_ops_exception_list + self._preserve_ops = _preserve_ops + self.aten_op_verifier = EXIRATenDialectVerifier( - exception_list=exception_list + core_aten_ops_exception_list=_core_aten_ops_exception_list, + preserve_ops=_preserve_ops, ) self.check_valid_aten_op = self.aten_op_verifier.check_valid_op @@ -237,7 +279,6 @@ def __init__(self) -> None: self.check_valid_op = self.check_valid_edge_op else: self.check_valid_op = self.check_valid_aten_op - self._exception_list = exception_list if exception_list else [] def allowed_getattr_types(self) -> Tuple[Type[Any], ...]: return ( @@ -258,7 +299,7 @@ def check_valid_edge_op(self, op): in [operator.getitem] + DISALLOW_LIST + list(_EXECUTORCH_SYM_OPS) - + self._exception_list + + self._core_aten_ops_exception_list ): return