
Commit b7c8d19

[ExecuTorch][to_backend] Enable passing Delegation Spec to to_backend
Pull Request resolved: #8165

This will be used for backend weight sharing, so backends that do entire-graph delegation can still share data across methods.

Support the entire-graph delegation flow through EdgeProgramManager's to_backend.

### Motivation
A current use case for backend lowering is the `to_backend(backend_id, exported_program, compile_spec)` API, which lowers the entire exported program to the specified backend_id. However, lowering via the EdgeProgramManager only allows partitioner-based lowering. The EdgeProgramManager is the main component that enables support for multiple methods; as a result, backends that rely on the old `to_backend(backend_id, ...)` API cannot export ExecuTorch models with multiple methods.

### Design
We extend EdgeProgramManager's to_backend so that the Partitioner argument can also be a DelegationSpec. A DelegationSpec is essentially a wrapper around the backend_id and the compile specs, so anywhere a partitioner is specified to lower a graph, a DelegationSpec can be used instead to lower the entire graph.

### Intended Flow
```
del_spec = DelegationSpec("BackendWithCompilerDemo", [CompileSpec(...)])

encode_graph = torch.export.export(Encoder(), sample_inputs)
decode_graph = torch.export.export(Decoder(), sample_inputs)

edge_manager = to_edge({
    "encode": encode_graph,
    "decode": decode_graph,
})

lowered_edge_manager = edge_manager.to_backend(del_spec)

# or, to lower only specific methods with del_spec
lowered_edge_manager = edge_manager.to_backend({
    "encode": del_spec,
})
```

ghstack-source-id: 266547681
@exported-using-ghexport

Differential Revision: [D69086565](https://our.internmc.facebook.com/intern/diff/D69086565/)
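For reference, the tests added in this commit exercise the same flow through the new AllNodePartitioner, which wraps a backend_id and compile specs in a DelegationSpec and tags every node for delegation. A minimal single-method sketch, abridged from those tests (the demo backend and the `max_value` compile spec come from the test suite):

```python
# Minimal sketch abridged from the tests in this commit: lower an entire
# single-method program through EdgeProgramManager.to_backend using the new
# AllNodePartitioner (backend_id + compile specs wrapped in a DelegationSpec).
import torch

from executorch.exir import to_edge
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
    AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec


class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)


model_inputs = (torch.ones(1),)
edge_manager = to_edge(torch.export.export(SinModule(), model_inputs))
lowered = edge_manager.to_backend(
    AllNodePartitioner(
        "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([1]))]
    )
)
exec_prog = lowered.to_executorch()  # the whole graph becomes a single delegate
```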
1 parent 0beadcc commit b7c8d19

6 files changed: +301 -0 lines changed

exir/backend/canonical_partitioners/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ runtime.python_library(
     srcs = [
         "duplicate_dequant_node_pass.py",
         "pattern_op_partitioner.py",
+        "all_node_partitioner.py",
     ],
     visibility = [
         "//executorch/...",
exir/backend/canonical_partitioners/all_node_partitioner.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param


def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool:
    """
    Returns true if the node is a placeholder node and it is not a tensor
    """
    return node.op == "placeholder" and not (
        is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
    )


class AllNodePartitioner(Partitioner):
    def __init__(
        self,
        backend_id: str,
        compile_specs: List[CompileSpec],
    ):
        """
        Partitioner that lowers every single node in the graph module to the
        specified backend_id
        """
        super().__init__()
        self.delegation_spec = DelegationSpec(backend_id, compile_specs)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # tag all nodes
        partition_tags: Dict[str, DelegationSpec] = {}
        for node in exported_program.graph_module.graph.nodes:
            if is_non_tensor_placeholder(node, exported_program) or node.op == "output":
                continue

            delegation_tag = self.delegation_spec.backend_id
            node.meta["delegation_tag"] = delegation_tag
            partition_tags[delegation_tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
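For illustration, a hypothetical snippet (not part of this diff) showing what partition() produces when called directly; to_backend uses this result to lower the whole graph as one delegate:

```python
# Hypothetical usage sketch (not part of this diff): inspect the result of
# AllNodePartitioner.partition() on a tiny exported program.
import torch

from executorch.exir import to_edge
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
    AllNodePartitioner,
)


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


edge = to_edge(torch.export.export(Add(), (torch.ones(2), torch.ones(2))))
partitioner = AllNodePartitioner("BackendWithCompilerDemo", [])
result = partitioner.partition(edge.exported_program())
# Compute nodes (plus any params/buffers/constants) now carry the backend_id
# as their delegation tag; partition_tags maps that tag to the DelegationSpec.
print(result.partition_tags)  # {"BackendWithCompilerDemo": DelegationSpec(...)}
```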

exir/backend/test/test_backends.py

Lines changed: 180 additions & 0 deletions
@@ -10,7 +10,11 @@
 
 import executorch.exir as exir
 import torch
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1266,3 +1270,179 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
+
+    def test_to_backend_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return [torch.sin(x)]
+
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        max_value = model_inputs[0].shape[0]
+
+        partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
+        )
+
+        edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))
+        edgeir_m = edgeir_m.to_backend(partitioner)
+        exec_prog = edgeir_m.to_executorch()
+        graph_module = exec_prog.exported_program().graph_module
+        # Check that there is not an aten.sin node.
+        self.assertTrue(
+            exir_ops.edge.aten.sin
+            not in {node.target for node in graph_module.graph.nodes}
+        )
+
+        # Check that there exists a call_delegate, representing the call to the
+        # delegated function
+        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+            graph_module.code
+        )
+        lowered_submodules = get_lowered_submodules(graph_module)
+        self.assertEqual(len(lowered_submodules), 1)
+
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target == executorch_call_delegate:
+                # Check that first arg is lowered_module_{unique_id}
+                self.assertEqual(node.args[0].target, "lowered_module_0")
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        # Check the backend delegate
+        self.check_backend_delegate(
+            program=program,
+            delegate=program.execution_plan[0].delegates[0],
+            expected_id=BackendWithCompilerDemo.__name__,
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+        )
+
+        # Check the delegate instruction
+        self.assertTrue(
+            isinstance(
+                program.execution_plan[0].chains[0].instructions[0].instr_args,
+                DelegateCall,
+            )
+        )
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        model_outputs = executorch_module.forward([model_inputs])
+        self.assertEqual(
+            model_inputs,
+            torch.ones(1),
+        )
+        expected_output = 0.8333 * torch.ones(1)
+
+        self.assertTrue(
+            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
+        )
+
+    def test_to_backend_multimethod_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+            def inputs(self):
+                return (torch.ones(1),)
+
+        class AddMulModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = torch.add(y, b)
+                return z
+
+            def inputs(self):
+                return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
+
+        sin_module = SinModule()
+        max_value_sin = sin_module.inputs()[0].shape[0]
+        sin_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_sin]))],
+        )
+
+        add_mul_module = AddMulModule()
+        max_value_add_mul = add_mul_module.inputs()[0].shape[0]
+        add_mul_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_add_mul]))],
+        )
+
+        edgeir_m = to_edge(
+            {
+                "sin": torch.export.export(sin_module, sin_module.inputs()),
+                "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),
+            }
+        )
+        edgeir_m = edgeir_m.to_backend(
+            {
+                "sin": sin_partitioner,
+                "add_mul": add_mul_partitioner,
+            }
+        )
+        exec_prog = edgeir_m.to_executorch()
+
+        for method_name in ["sin", "add_mul"]:
+            graph_module = exec_prog.exported_program(method_name).graph_module
+            # Check delegated nodes are gone
+            self.assertTrue(
+                exir_ops.edge.aten.sin
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.add
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.mm
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            # Check that there exists a call_delegate, representing the call to the
+            # delegated function
+            FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+                graph_module.code
+            )
+            lowered_submodules = get_lowered_submodules(graph_module)
+            self.assertEqual(len(lowered_submodules), 1)
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+
+        for method_name, module in {
+            "sin": sin_module,
+            "add_mul": add_mul_module,
+        }.items():
+            inputs_flattened, _ = tree_flatten(module.inputs())
+            model_outputs = executorch_module.run_method(
+                method_name, tuple(inputs_flattened)
+            )
+
+            if method_name == "sin":
+                # backend with compiler demo does a taylor approximation of sin
+                ref_output = 0.8333 * torch.ones(1)
+            else:
+                ref_output = module(*module.inputs())
+            self.assertTrue(
+                torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
+            )

exir/backend/test/test_backends_lifted.py

Lines changed: 15 additions & 0 deletions
@@ -11,6 +11,9 @@
 import torch
 from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -138,6 +141,18 @@ def forward(self, x):
 
         self.assertTrue(torch.allclose(new_res, expected_res))
 
+        # Test same flow but through edge_program_manager
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        loweredir_m = edgeir_m.to_backend(
+            AllNodePartitioner(BackendWithCompilerDemo.__name__, [])
+        )
+        lowered_sin_module = get_lowered_submodules(
+            loweredir_m.exported_program().graph_module
+        )[0][1]
+
+        new_res = lowered_sin_module(*model_inputs)[0]
+
+        self.assertTrue(torch.allclose(new_res, expected_res))
         # TODO(tkaruturi): emitting single LoweredBackendModule
         # program = to_edge(export(graph_module)).to_exectorch()._emitter_output.program
 

exir/backend/test/test_compatibility.py

Lines changed: 49 additions & 0 deletions
@@ -10,6 +10,9 @@
 from executorch.exir import to_edge
 from executorch.exir._serialize import _serialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
@@ -65,3 +68,49 @@ def forward(self, x):
             "loading method forward failed with error 0x30",
         ):
             executorch_module = _load_for_executorch_from_buffer(buff)
+
+    def test_compatibility_in_runtime_edge_program_manager(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        max_value = model_inputs[0].shape[0]
+        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
+        lowered_edge_irm = edgeir_m.to_backend(
+            AllNodePartitioner("BackendWithCompilerDemo", compile_specs)
+        )
+        exec_prog = lowered_edge_irm.to_executorch()
+
+        buff = exec_prog.buffer
+
+        # The demo backend works well
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        _ = executorch_module.forward([model_inputs])
+
+        prog = exec_prog.executorch_program
+        # Rewrite the delegate version number from 0 to 1.
+        prog.backend_delegate_data[0].data = bytes(
+            "1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
+            encoding="utf8",
+        )
+
+        # Generate the .pte file with the wrong version.
+        buff = bytes(
+            _serialize_pte_binary(
+                program=prog,
+            )
+        )
+
+        # Throw runtime error with error code 0x30, meaning delegate is incompatible.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "loading method forward failed with error 0x30",
+        ):
+            executorch_module = _load_for_executorch_from_buffer(buff)

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ python_library(
         "//executorch/exir/_serialize:lib",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
         "//executorch/exir/capture:config",
         "//executorch/exir/emit:emit",
         "//executorch/exir/emit:lib",
