
Commit 3d0edb8

[ExecuTorch][to_backend] Enable passing Delegation Spec to to_backend
Support the entire-graph delegation flow through EdgeProgramManager's to_backend.

### Motivation

A current use case for backend lowering is the `to_backend(backend_id, exported_program, compile_spec)` API, which lowers the entire exported program to the specified backend_id. However, lowering via the EdgeProgramManager only allows partitioner-based lowering. Since the EdgeProgramManager is the main component that enables support for multiple methods, backends that rely on the old `to_backend(backend_id, ...)` API cannot export ExecuTorch models with multiple methods.

### Design

We extend EdgeProgramManager so that the Partitioner argument to to_backend can be replaced by a DelegationSpec. A DelegationSpec is essentially a wrapper around the backend_id and the compile_specs, so anywhere a partitioner is specified to lower a graph, a delegation spec can be used instead to perform entire-graph lowering.

### Intended Flow

```
del_spec = DelegationSpec("BackendWithCompilerDemo", [CompileSpec(...)])

encode_graph = torch.export.export(Encoder(), sample_inputs)
decode_graph = torch.export.export(Decoder(), sample_inputs)

edge_manager = to_edge({
    "encode": encode_graph,
    "decode": decode_graph,
})

lowered_edge_manager = edge_manager.to_backend(del_spec)

# or, to lower only specific methods with del_spec:
lowered_edge_manager = edge_manager.to_backend({
    "encode": del_spec,
})
```

Differential Revision: [D69086565](https://our.internmc.facebook.com/intern/diff/D69086565/)

[ghstack-poisoned]
1 parent 8c82000 commit 3d0edb8
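
For orientation before the diff: a minimal sketch of the two types involved and of how a DelegationSpec can be routed through the new AllNodePartitioner. The field shapes are inferred from how this commit's tests construct these objects, and the routing helper `resolve_to_partitioner` is hypothetical, since the actual EdgeProgramManager.to_backend change is not among the hunks shown on this page.

```python
# Sketch only: field shapes inferred from usage in this commit's tests; the
# real classes live in executorch.exir.backend (compile_spec_schema, partitioner).
from typing import List, NamedTuple


class CompileSpec(NamedTuple):
    key: str      # e.g. "max_value"
    value: bytes  # backend-specific payload


class DelegationSpec(NamedTuple):
    backend_id: str                   # e.g. "BackendWithCompilerDemo"
    compile_specs: List[CompileSpec]


def resolve_to_partitioner(spec_or_partitioner):
    """Hypothetical routing: wrap a DelegationSpec in the new
    AllNodePartitioner so the existing partitioner-based lowering path
    delegates the entire graph to a single backend."""
    if isinstance(spec_or_partitioner, DelegationSpec):
        # AllNodePartitioner is added by this commit; the import path is
        # inferred from the new file's location in the diff below.
        from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
            AllNodePartitioner,
        )
        return AllNodePartitioner(spec_or_partitioner)
    return spec_or_partitioner  # assumed to already be a Partitioner
```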

File tree: 8 files changed, +314 -10 lines

exir/backend/canonical_partitioners/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -7,6 +7,7 @@ runtime.python_library(
     srcs = [
         "duplicate_dequant_node_pass.py",
         "pattern_op_partitioner.py",
+        "all_node_partitioner.py",
     ],
     visibility = [
         "//executorch/...",
```
exir/backend/canonical_partitioners/all_node_partitioner.py

Lines changed: 54 additions & 0 deletions (new file)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param


def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool:
    """
    Returns true if the node is a placeholder node and it is not a tensor
    """
    return node.op == "placeholder" and not (
        is_param(ep, node)
        or is_buffer(ep, node)
        or is_lifted_tensor_constant(ep, node)
    )


class AllNodePartitioner(Partitioner):
    def __init__(
        self,
        delegation_spec: DelegationSpec,
    ):
        """
        Partitioner that lowers every single node in the graph module to the
        specified backend_id
        """
        super().__init__()
        self.delegation_spec = delegation_spec

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # tag all nodes
        partition_tags: Dict[str, DelegationSpec] = {}
        for node in exported_program.graph_module.graph.nodes:
            if is_non_tensor_placeholder(node, exported_program) or node.op == "output":
                continue

            delegation_tag = self.delegation_spec.backend_id
            node.meta["delegation_tag"] = delegation_tag
            partition_tags[delegation_tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
```
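
As a usage note: because AllNodePartitioner is an ordinary Partitioner, the same whole-graph lowering can also be requested through the existing partitioner-based `to_backend` path. A minimal sketch, assuming the import path matches the file's location above and that the demo backend from the tests below is registered:

```python
import torch

from executorch.exir import to_edge
# Import path inferred from exir/backend/canonical_partitioners/all_node_partitioner.py.
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
    AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import DelegationSpec


class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)


model_inputs = (torch.ones(1),)
del_spec = DelegationSpec(
    "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([1]))]
)

edge_manager = to_edge(torch.export.export(SinModule(), model_inputs))

# AllNodePartitioner tags every node with del_spec's backend_id, so the
# partitioner path lowers the entire graph, like to_backend(del_spec) does.
lowered = edge_manager.to_backend(AllNodePartitioner(del_spec))
exec_prog = lowered.to_executorch()
```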

exir/backend/test/test_backends.py

Lines changed: 166 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@

 import executorch.exir as exir
 import torch
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
```

```diff
@@ -1266,3 +1267,168 @@ def forward(self, x: List[torch.Tensor]):

         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
+
+    def test_to_backend_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return [torch.sin(x)]
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        max_value = model_inputs[0].shape[0]
+
+        del_spec = DelegationSpec("BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))])
+
+        edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))
+        edgeir_m = edgeir_m.to_backend(del_spec)
+        exec_prog = edgeir_m.to_executorch()
+        graph_module = exec_prog.exported_program().graph_module
+        # Check that there is not an aten.sin node.
+        self.assertTrue(
+            exir_ops.edge.aten.sin
+            not in {node.target for node in graph_module.graph.nodes}
+        )
+
+        # Check that there exists a call_delegate, representing the call to the
+        # delegated function
+        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+            graph_module.code
+        )
+        lowered_submodules = get_lowered_submodules(graph_module)
+        self.assertEqual(len(lowered_submodules), 1)
+
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target == executorch_call_delegate:
+                # Check that first arg is lowered_module_{unique_id}
+                self.assertEqual(node.args[0].target, "lowered_module_0")
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        # Check the backend delegate
+        self.check_backend_delegate(
+            program=program,
+            delegate=program.execution_plan[0].delegates[0],
+            expected_id=BackendWithCompilerDemo.__name__,
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+        )
+
+        # Check the delegate instruction
+        self.assertTrue(
+            isinstance(
+                program.execution_plan[0].chains[0].instructions[0].instr_args,
+                DelegateCall,
+            )
+        )
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        model_outputs = executorch_module.forward([model_inputs])
+        self.assertEqual(
+            model_inputs,
+            torch.ones(1),
+        )
+        expected_output = 0.8333 * torch.ones(1)
+
+        self.assertTrue(
+            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
+        )
+
+    def test_to_backend_multimethod_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+            def inputs(self):
+                return (torch.ones(1),)
+
+        class AddMulModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = torch.add(y, b)
+                return z
+
+            def inputs(self):
+                return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
+
+        sin_module = SinModule()
+        max_value_sin = sin_module.inputs()[0].shape[0]
+        del_spec_sin = DelegationSpec("BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value_sin]))])
+
+        add_mul_module = AddMulModule()
+        max_value_add_mul = add_mul_module.inputs()[0].shape[0]
+        del_spec_add_mul = DelegationSpec("BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value_add_mul]))])
+
+        edgeir_m = to_edge(
+            {
+                "sin": torch.export.export(sin_module, sin_module.inputs()),
+                "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),
+            }
+        )
+        edgeir_m = edgeir_m.to_backend(
+            {
+                "sin": del_spec_sin,
+                "add_mul": del_spec_add_mul,
+            }
+        )
+        exec_prog = edgeir_m.to_executorch()
+
+        for method_name in ["sin", "add_mul"]:
+            graph_module = exec_prog.exported_program(method_name).graph_module
+            # Check delegated nodes are gone
+            self.assertTrue(
+                exir_ops.edge.aten.sin
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.add
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.mm
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            # Check that there exists a call_delegate, representing the call to the
+            # delegated function
+            FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+                graph_module.code
+            )
+            lowered_submodules = get_lowered_submodules(graph_module)
+            self.assertEqual(len(lowered_submodules), 1)
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+
+        for method_name, module in {"sin": sin_module, "add_mul": add_mul_module}.items():
+            inputs_flattened, _ = tree_flatten(module.inputs())
+            model_outputs = executorch_module.run_method(method_name, tuple(inputs_flattened))
+
+            if method_name == "sin":
+                # backend with compiler demo does a taylor approximation of sin
+                ref_output = 0.8333 * torch.ones(1)
+            else:
+                ref_output = module(*module.inputs())
+            self.assertTrue(
+                torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
+            )
```

exir/backend/test/test_backends_lifted.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -138,6 +138,14 @@ def forward(self, x):

         self.assertTrue(torch.allclose(new_res, expected_res))

+        # Test same flow but through edge_program_manager
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        loweredir_m = edgeir_m.to_backend(DelegationSpec(BackendWithCompilerDemo.__name__, []))
+        lowered_sin_module = get_lowered_submodules(loweredir_m.exported_program().graph_module)[0][1]
+
+        new_res = lowered_sin_module(*model_inputs)[0]
+
+        self.assertTrue(torch.allclose(new_res, expected_res))
         # TODO(tkaruturi): emitting single LoweredBackendModule
         # program = to_edge(export(graph_module)).to_exectorch()._emitter_output.program
```

exir/backend/test/test_compatibility.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
 from executorch.exir._serialize import _serialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import DelegationSpec
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
 )
```

```diff
@@ -65,3 +66,47 @@ def forward(self, x):
             "loading method forward failed with error 0x30",
         ):
             executorch_module = _load_for_executorch_from_buffer(buff)
+
+    def test_compatibility_in_runtime_edge_program_manager(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        max_value = model_inputs[0].shape[0]
+        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
+        lowered_edge_irm = edgeir_m.to_backend(DelegationSpec("BackendWithCompilerDemo", compile_specs))
+        exec_prog = lowered_edge_irm.to_executorch()
+
+        buff = exec_prog.buffer
+
+        # The demo backend works well
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        _ = executorch_module.forward([model_inputs])
+
+        prog = exec_prog.executorch_program
+        # Rewrite the delegate version number from 0 to 1.
+        prog.backend_delegate_data[0].data = bytes(
+            "1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
+            encoding="utf8",
+        )
+
+        # Generate the .pte file with the wrong version.
+        buff = bytes(
+            _serialize_pte_binary(
+                program=prog,
+            )
+        )
+
+        # Throw runtime error with error code 0x30, meaning delegate is incompatible.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "loading method forward failed with error 0x30",
+        ):
+            executorch_module = _load_for_executorch_from_buffer(buff)
```

exir/backend/test/test_delegate_map_builder.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -9,10 +9,13 @@

 import torch
 from executorch import exir
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.test.backend_with_delegate_mapping_demo import (
     BackendWithDelegateMappingDemo,
 )
+from executorch.exir.backend.partitioner import DelegationSpec
+from executorch.exir.lowered_backend_module import get_lowered_submodules

 from executorch.exir.backend.utils import DelegateMappingBuilder
```

```diff
@@ -162,6 +165,22 @@ def forward(self, x):
             composite_model, inputs, exir.CaptureConfig()
         ).to_edge().to_executorch()

+    def test_backend_with_delegate_mapping_delegation_spec(self) -> None:
+        model, inputs = BackendWithDelegateMappingDemo.get_test_model_and_inputs()
+        edgeir_m = to_edge(torch.export.export(model, inputs))
+        lowered_m = edgeir_m.to_backend(
+            DelegationSpec("BackendWithDelegateMappingDemo", [])
+        )
+        lowered_submodule = get_lowered_submodules(lowered_m.exported_program().graph_module)[0][1]
+        debug_handle_map = lowered_submodule.meta.get("debug_handle_map")
+        self.assertIsNotNone(debug_handle_map)
+        # The delegate debug handle map should contain 5 entries for this model.
+        self.assertEqual(len(debug_handle_map), 5)
+        # Check that the delegate debug indexes 1 through 4 are all present.
+        self.assertTrue(
+            all(element in debug_handle_map.keys() for element in [1, 2, 3, 4])
+        )
+
     def test_passing_both_nodes_and_handles(self):
         delegate_builder = DelegateMappingBuilder()
```

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ python_library(
         "//executorch/exir/_serialize:lib",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
         "//executorch/exir/capture:config",
         "//executorch/exir/emit:emit",
         "//executorch/exir/emit:lib",
```
