Skip to content

Commit 32e82bc

Browse files
let backends set info on the lowered_backend.meta from preprocess
Differential Revision: D81466391 Pull Request resolved: #13856
1 parent 9fa7edf commit 32e82bc

File tree

6 files changed

+79
-3
lines changed

6 files changed

+79
-3
lines changed

exir/backend/backend_api.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def to_backend(
126126
lowered_module.meta = {
127127
"debug_handle_map": preprocess_result.debug_handle_map
128128
}
129+
if preprocess_result._delegate_info_meta is not None:
130+
lowered_module.meta["_delegate_info_meta"] = (
131+
preprocess_result._delegate_info_meta
132+
)
129133
return lowered_module
130134
raise NotImplementedError(f"Backend {backend_id} was not found.")
131135

@@ -610,6 +614,11 @@ def lower_all_submodules_to_backend(
610614
lowered_module.meta = {
611615
"debug_handle_map": preprocess_result.debug_handle_map,
612616
}
617+
if preprocess_result._delegate_info_meta is not None:
618+
assert lowered_module.meta is not None
619+
lowered_module.meta["_delegate_info_meta"] = (
620+
preprocess_result._delegate_info_meta
621+
)
613622
is_submodule = call_submodule_node.meta["is_submodule"]
614623
toplevel_input_specs_to_delete = call_submodule_node.meta[
615624
"toplevel_input_specs_to_delete"

exir/backend/backend_details.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from abc import ABC, abstractmethod
88
from dataclasses import dataclass
99

10-
from typing import Dict, List, Optional, Tuple, Union
10+
from typing import Any, Dict, List, Optional, Tuple, Union
1111

1212
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1313

@@ -32,6 +32,11 @@ class PreprocessResult:
3232
# but retrieveable by delegates via the NamedDataMap at runtime.
3333
data_store_output: Optional[NamedDataStoreOutput] = None
3434

35+
# Optional delegate-specific information that will be added to the
36+
# lowered_module.meta field in the graph, but not directly serialized
37+
# into the PTE file.
38+
_delegate_info_meta: Optional[Any] = None
39+
3540

3641
"""
3742
How to create a backend (for example, BackendWithCompilerDemo):

exir/backend/test/backend_with_compiler_demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,5 @@ def preprocess(
138138
encoding="utf8",
139139
),
140140
debug_handle_map=debug_handle_map,
141+
_delegate_info_meta="test",
141142
)

exir/backend/test/test_backends_lifted.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,3 +1264,63 @@ def forward(self, x: List[torch.Tensor]):
12641264

12651265
gm = to_edge(export(ComposedM(), inputs, strict=True))
12661266
gm.exported_program().module()(*inputs)
1267+
1268+
def test_delegate_info_full_delegate(self):
1269+
"""
1270+
Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1271+
when using full delegation (to_backend directly).
1272+
"""
1273+
1274+
class SinModule(torch.nn.Module):
1275+
def __init__(self):
1276+
super().__init__()
1277+
1278+
def forward(self, x):
1279+
return torch.sin(x)
1280+
1281+
sin_module = SinModule()
1282+
model_inputs = (torch.ones(1),)
1283+
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
1284+
max_value = model_inputs[0].shape[0]
1285+
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
1286+
lowered_sin_module = to_backend(
1287+
"BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
1288+
)
1289+
1290+
# Check that the lowered module has _delegate_info_meta in its meta
1291+
self.assertIn("_delegate_info_meta", lowered_sin_module.meta.keys())
1292+
self.assertEqual(lowered_sin_module.meta["_delegate_info_meta"], "test")
1293+
1294+
def test_delegate_info_partitioner(self):
1295+
"""
1296+
Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1297+
when using partitioner-based delegation.
1298+
"""
1299+
1300+
class SinModule(torch.nn.Module):
1301+
def __init__(self):
1302+
super().__init__()
1303+
1304+
def forward(self, x):
1305+
return torch.sin(x)
1306+
1307+
sin_module = SinModule()
1308+
model_inputs = (torch.ones(1),)
1309+
max_value = model_inputs[0].shape[0]
1310+
1311+
partitioner = AllNodePartitioner(
1312+
"BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
1313+
)
1314+
1315+
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
1316+
lowered_m = edgeir_m.to_backend(partitioner)
1317+
1318+
# Check that the lowered submodule has _delegate_info_meta in its meta
1319+
lowered_submodules = get_lowered_submodules(
1320+
lowered_m.exported_program().graph_module
1321+
)
1322+
self.assertEqual(len(lowered_submodules), 1)
1323+
1324+
lowered_module = lowered_submodules[0][1]
1325+
self.assertIn("_delegate_info_meta", lowered_module.meta)
1326+
self.assertEqual(lowered_module.meta["_delegate_info_meta"], "test")

exir/emit/_emitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ def _add_delegate_map(
10301030
code, module hierarchy etc.
10311031
"""
10321032
delegate_map = {}
1033-
if hasattr(lowered_module, "meta"):
1033+
if lowered_module.meta is not None:
10341034
delegate_map = lowered_module.meta.get("debug_handle_map", {})
10351035

10361036
self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {

exir/lowered_backend_module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class LoweredBackendModule(torch.nn.Module):
6666
_named_data_store_output: Optional[
6767
NamedDataStoreOutput
6868
] # Named Data serialized by the backend
69+
meta: Optional[Dict[str, Any]] # Metadata for the lowered module
6970

7071
def __init__(
7172
self,
@@ -81,6 +82,7 @@ def __init__(
8182
self._processed_bytes = processed_bytes
8283
self._compile_specs = compile_specs
8384
self._named_data_store_output = named_data_store_output
85+
self.meta = None
8486

8587
# pyre-ignore
8688
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
@@ -109,7 +111,6 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule"
109111
compile_specs=copy.deepcopy(self._compile_specs, memo),
110112
named_data_store_output=self._named_data_store_output,
111113
)
112-
# pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
113114
res.meta = copy.copy(getattr(self, "meta", {}))
114115
return res
115116

0 commit comments

Comments
 (0)