
Commit 4134815

[ExecuTorch][to_backend] Enable passing Delegation Spec to to_backend
Pull Request resolved: #8165

This will be used for backend weight sharing, so backends that delegate the entire graph can still share data across methods. Support the entire-graph delegation flow through EdgeProgramManager's to_backend.

### Motivation

A current use case for backend lowering is the `to_backend(backend_id, exported_program, compile_spec)` API, which lowers the entire exported program to the specified backend_id. However, lowering via the EdgeProgramManager only allows partitioner-based lowering. The EdgeProgramManager is the main component that enables support for multiple methods; as a result, backends that rely on the old `to_backend(backend_id, ...)` API cannot export ExecuTorch models with multiple methods.

### Design

We extend EdgeProgramManager.to_backend so that the Partitioner argument can be replaced by a DelegationSpec. DelegationSpec is essentially a wrapper around the backend_id and the compile_specs, so anywhere a partitioner is specified to lower a graph, a delegation spec can instead be used to lower the entire graph.

### Intended Flow

```
del_spec = DelegationSpec("BackendWithCompilerDemo", [CompileSpec(...)])

encode_graph = torch.export.export(Encoder(), sample_inputs)
decode_graph = torch.export.export(Decoder(), sample_inputs)

edge_manager = to_edge({
    "encode": encode_graph,
    "decode": decode_graph,
})

lowered_edge_manager = edge_manager.to_backend(del_spec)

# or if you want to specify which methods to lower with del_spec
lowered_edge_manager = edge_manager.to_backend({
    "encode": del_spec,
})
```

ghstack-source-id: 266326224
@exported-using-ghexport

Differential Revision: [D69086565](https://our.internmc.facebook.com/intern/diff/D69086565/)
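For context, here is a rough sketch of the shape of DelegationSpec as described above. The real definition lives in `executorch/exir/backend/partitioner.py`; the field names below are inferred from how this diff constructs and reads it (`DelegationSpec(backend_id, compile_specs)`, `spec.backend_id`), not copied from that file.

```python
# Rough sketch only: the actual DelegationSpec is defined in
# executorch/exir/backend/partitioner.py; field names are inferred from this diff.
from typing import List, NamedTuple

from executorch.exir.backend.compile_spec_schema import CompileSpec


class DelegationSpec(NamedTuple):
    backend_id: str
    compile_specs: List[CompileSpec]
```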
1 parent 0beadcc commit 4134815

6 files changed: +303 additions, −0 deletions

exir/backend/canonical_partitioners/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -7,6 +7,7 @@ runtime.python_library(
     srcs = [
         "duplicate_dequant_node_pass.py",
         "pattern_op_partitioner.py",
+        "all_node_partitioner.py",
     ],
     visibility = [
         "//executorch/...",
```
exir/backend/canonical_partitioners/all_node_partitioner.py

Lines changed: 54 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param


def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool:
    """
    Returns true if the node is a placeholder node and it is not a tensor
    """
    return node.op == "placeholder" and not (
        is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
    )


class AllNodePartitioner(Partitioner):
    def __init__(
        self,
        backend_id: str,
        compile_specs: List[CompileSpec],
    ):
        """
        Partitioner that lowers every single node in the graph module to the
        specified backend_id
        """
        super().__init__()
        self.delegation_spec = DelegationSpec(backend_id, compile_specs)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # tag all nodes
        partition_tags: Dict[str, DelegationSpec] = {}
        for node in exported_program.graph_module.graph.nodes:
            if is_non_tensor_placeholder(node, exported_program) or node.op == "output":
                continue

            delegation_tag = self.delegation_spec.backend_id
            node.meta["delegation_tag"] = delegation_tag
            partition_tags[delegation_tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
```
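A minimal usage sketch of this partitioner through the EdgeProgramManager flow, mirroring the tests below. It assumes the BackendWithCompilerDemo test backend is registered (as it is in those tests); the module and its inputs are illustrative only.

```python
# Minimal sketch, assuming the BackendWithCompilerDemo demo backend is available;
# SinModule and its inputs are illustrative stand-ins.
import torch

from executorch.exir import to_edge
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
    AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec


class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)


model_inputs = (torch.ones(1),)
edge_manager = to_edge(torch.export.export(SinModule(), model_inputs))

# Every node in the graph gets the same delegation tag, so the whole method is
# lowered to the named backend as a single delegate.
lowered = edge_manager.to_backend(
    AllNodePartitioner("BackendWithCompilerDemo", [CompileSpec("max_value", bytes([1]))])
)
exec_prog = lowered.to_executorch()
```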

exir/backend/test/test_backends.py

Lines changed: 181 additions & 0 deletions
```diff
@@ -10,7 +10,12 @@
 
 import executorch.exir as exir
 import torch
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1266,3 +1271,179 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
+
+    def test_to_backend_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return [torch.sin(x)]
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        max_value = model_inputs[0].shape[0]
+
+        partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
+        )
+
+        edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))
+        edgeir_m = edgeir_m.to_backend(partitioner)
+        exec_prog = edgeir_m.to_executorch()
+        graph_module = exec_prog.exported_program().graph_module
+        # Check that there is not an aten.sin node.
+        self.assertTrue(
+            exir_ops.edge.aten.sin
+            not in {node.target for node in graph_module.graph.nodes}
+        )
+
+        # Check that there exists a call_delegate, representing the call to the
+        # delegated function
+        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+            graph_module.code
+        )
+        lowered_submodules = get_lowered_submodules(graph_module)
+        self.assertEqual(len(lowered_submodules), 1)
+
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target == executorch_call_delegate:
+                # Check that first arg is lowered_module_{unique_id}
+                self.assertEqual(node.args[0].target, "lowered_module_0")
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        # Check the backend delegate
+        self.check_backend_delegate(
+            program=program,
+            delegate=program.execution_plan[0].delegates[0],
+            expected_id=BackendWithCompilerDemo.__name__,
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+        )
+
+        # Check the delegate instruction
+        self.assertTrue(
+            isinstance(
+                program.execution_plan[0].chains[0].instructions[0].instr_args,
+                DelegateCall,
+            )
+        )
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        model_outputs = executorch_module.forward([model_inputs])
+        self.assertEqual(
+            model_inputs,
+            torch.ones(1),
+        )
+        expected_output = 0.8333 * torch.ones(1)
+
+        self.assertTrue(
+            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
+        )
+
+    def test_to_backend_multimethod_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+            def inputs(self):
+                return (torch.ones(1),)
+
+        class AddMulModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = torch.add(y, b)
+                return z
+
+            def inputs(self):
+                return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
+
+        sin_module = SinModule()
+        max_value_sin = sin_module.inputs()[0].shape[0]
+        sin_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_sin]))],
+        )
+
+        add_mul_module = AddMulModule()
+        max_value_add_mul = add_mul_module.inputs()[0].shape[0]
+        add_mul_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_add_mul]))],
+        )
+
+        edgeir_m = to_edge(
+            {
+                "sin": torch.export.export(sin_module, sin_module.inputs()),
+                "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),
+            }
+        )
+        edgeir_m = edgeir_m.to_backend(
+            {
+                "sin": sin_partitioner,
+                "add_mul": add_mul_partitioner,
+            }
+        )
+        exec_prog = edgeir_m.to_executorch()
+
+        for method_name in ["sin", "add_mul"]:
+            graph_module = exec_prog.exported_program(method_name).graph_module
+            # Check delegated nodes are gone
+            self.assertTrue(
+                exir_ops.edge.aten.sin
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.add
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.mm
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            # Check that there exists a call_delegate, representing the call to the
+            # delegated function
+            FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+                graph_module.code
+            )
+            lowered_submodules = get_lowered_submodules(graph_module)
+            self.assertEqual(len(lowered_submodules), 1)
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+
+        for method_name, module in {
+            "sin": sin_module,
+            "add_mul": add_mul_module,
+        }.items():
+            inputs_flattened, _ = tree_flatten(module.inputs())
+            model_outputs = executorch_module.run_method(
+                method_name, tuple(inputs_flattened)
+            )
+
+            if method_name == "sin":
+                # backend with compiler demo does a taylor approximation of sin
+                ref_output = 0.8333 * torch.ones(1)
+            else:
+                ref_output = module(*module.inputs())
+            self.assertTrue(
+                torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
+            )
```

exir/backend/test/test_backends_lifted.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -11,6 +11,10 @@
 import torch
 from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -138,6 +142,18 @@ def forward(self, x):
 
         self.assertTrue(torch.allclose(new_res, expected_res))
 
+        # Test same flow but through edge_program_manager
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        loweredir_m = edgeir_m.to_backend(
+            AllNodePartitioner(BackendWithCompilerDemo.__name__, [])
+        )
+        lowered_sin_module = get_lowered_submodules(
+            loweredir_m.exported_program().graph_module
+        )[0][1]
+
+        new_res = lowered_sin_module(*model_inputs)[0]
+
+        self.assertTrue(torch.allclose(new_res, expected_res))
         # TODO(tkaruturi): emitting single LoweredBackendModule
         # program = to_edge(export(graph_module)).to_exectorch()._emitter_output.program
```

exir/backend/test/test_compatibility.py

Lines changed: 50 additions & 0 deletions
```diff
@@ -10,6 +10,10 @@
 from executorch.exir import to_edge
 from executorch.exir._serialize import _serialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
@@ -65,3 +69,49 @@ def forward(self, x):
             "loading method forward failed with error 0x30",
         ):
             executorch_module = _load_for_executorch_from_buffer(buff)
+
+    def test_compatibility_in_runtime_edge_program_manager(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        max_value = model_inputs[0].shape[0]
+        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
+        lowered_edge_irm = edgeir_m.to_backend(
+            AllNodePartitioner("BackendWithCompilerDemo", compile_specs)
+        )
+        exec_prog = lowered_edge_irm.to_executorch()
+
+        buff = exec_prog.buffer
+
+        # The demo backend works well
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        _ = executorch_module.forward([model_inputs])
+
+        prog = exec_prog.executorch_program
+        # Rewrite the delegate version number from 0 to 1.
+        prog.backend_delegate_data[0].data = bytes(
+            "1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
+            encoding="utf8",
+        )
+
+        # Generate the .pte file with the wrong version.
+        buff = bytes(
+            _serialize_pte_binary(
+                program=prog,
+            )
+        )
+
+        # Throw runtime error with error code 0x30, meaning delegate is incompatible.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "loading method forward failed with error 0x30",
+        ):
+            executorch_module = _load_for_executorch_from_buffer(buff)
```

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ python_library(
         "//executorch/exir/_serialize:lib",
         "//executorch/exir/backend:backend_api",
        "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
        "//executorch/exir/capture:config",
        "//executorch/exir/emit:emit",
        "//executorch/exir/emit:lib",
```
