diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 6bb2df3dfdb..dd8d97d66ac 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -126,6 +126,10 @@ def to_backend( lowered_module.meta = { "debug_handle_map": preprocess_result.debug_handle_map } + if preprocess_result._delegate_info_meta is not None: + lowered_module.meta["_delegate_info_meta"] = ( + preprocess_result._delegate_info_meta + ) return lowered_module raise NotImplementedError(f"Backend {backend_id} was not found.") @@ -610,6 +614,11 @@ def lower_all_submodules_to_backend( lowered_module.meta = { "debug_handle_map": preprocess_result.debug_handle_map, } + if preprocess_result._delegate_info_meta is not None: + assert lowered_module.meta is not None + lowered_module.meta["_delegate_info_meta"] = ( + preprocess_result._delegate_info_meta + ) is_submodule = call_submodule_node.meta["is_submodule"] toplevel_input_specs_to_delete = call_submodule_node.meta[ "toplevel_input_specs_to_delete" diff --git a/exir/backend/backend_details.py b/exir/backend/backend_details.py index 513ae7c64b3..6999dadb9f9 100644 --- a/exir/backend/backend_details.py +++ b/exir/backend/backend_details.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from executorch.exir._serialize._named_data_store import NamedDataStoreOutput @@ -32,6 +32,11 @@ class PreprocessResult: # but retrieveable by delegates via the NamedDataMap at runtime. data_store_output: Optional[NamedDataStoreOutput] = None + # Optional delegate-specific information that will be added to the + # lowered_module.meta field in the graph, but not directly serialized + # into the PTE file. + _delegate_info_meta: Optional[Any] = None + """ How to create a backend (for example, BackendWithCompilerDemo): diff --git a/exir/backend/test/backend_with_compiler_demo.py b/exir/backend/test/backend_with_compiler_demo.py index b419db153ee..aa60fa2154c 100644 --- a/exir/backend/test/backend_with_compiler_demo.py +++ b/exir/backend/test/backend_with_compiler_demo.py @@ -138,4 +138,5 @@ def preprocess( encoding="utf8", ), debug_handle_map=debug_handle_map, + _delegate_info_meta="test", ) diff --git a/exir/backend/test/test_backends_lifted.py b/exir/backend/test/test_backends_lifted.py index b6aea7f8bb3..7d5eea3f0f4 100644 --- a/exir/backend/test/test_backends_lifted.py +++ b/exir/backend/test/test_backends_lifted.py @@ -1264,3 +1264,63 @@ def forward(self, x: List[torch.Tensor]): gm = to_edge(export(ComposedM(), inputs, strict=True)) gm.exported_program().module()(*inputs) + + def test_delegate_info_full_delegate(self): + """ + Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata + when using full delegation (to_backend directly). + """ + + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + sin_module = SinModule() + model_inputs = (torch.ones(1),) + edgeir_m = to_edge(export(sin_module, model_inputs, strict=True)) + max_value = model_inputs[0].shape[0] + compile_specs = [CompileSpec("max_value", bytes([max_value]))] + lowered_sin_module = to_backend( + "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs + ) + + # Check that the lowered module has _delegate_info_meta in its meta + self.assertIn("_delegate_info_meta", lowered_sin_module.meta.keys()) + self.assertEqual(lowered_sin_module.meta["_delegate_info_meta"], "test") + + def test_delegate_info_partitioner(self): + """ + Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata + when using partitioner-based delegation. + """ + + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + sin_module = SinModule() + model_inputs = (torch.ones(1),) + max_value = model_inputs[0].shape[0] + + partitioner = AllNodePartitioner( + "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))] + ) + + edgeir_m = to_edge(export(sin_module, model_inputs, strict=True)) + lowered_m = edgeir_m.to_backend(partitioner) + + # Check that the lowered submodule has _delegate_info_meta in its meta + lowered_submodules = get_lowered_submodules( + lowered_m.exported_program().graph_module + ) + self.assertEqual(len(lowered_submodules), 1) + + lowered_module = lowered_submodules[0][1] + self.assertIn("_delegate_info_meta", lowered_module.meta) + self.assertEqual(lowered_module.meta["_delegate_info_meta"], "test") diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 5ee8ca56091..80ba389c270 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1030,7 +1030,7 @@ def _add_delegate_map( code, module hierarchy etc. """ delegate_map = {} - if hasattr(lowered_module, "meta"): + if lowered_module.meta is not None: delegate_map = lowered_module.meta.get("debug_handle_map", {}) self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = { diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index e1dd7cb4079..61414990703 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -66,6 +66,7 @@ class LoweredBackendModule(torch.nn.Module): _named_data_store_output: Optional[ NamedDataStoreOutput ] # Named Data serialized by the backend + meta: Optional[Dict[str, Any]] # Metadata for the lowered module def __init__( self, @@ -81,6 +82,7 @@ def __init__( self._processed_bytes = processed_bytes self._compile_specs = compile_specs self._named_data_store_output = named_data_store_output + self.meta = None # pyre-ignore def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule": @@ -109,7 +111,6 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule" compile_specs=copy.deepcopy(self._compile_specs, memo), named_data_store_output=self._named_data_store_output, ) - # pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`. res.meta = copy.copy(getattr(self, "meta", {})) return res