Commit 5f76e91

Update on "[ExecuTorch][to_backend] Enable passing Delegation Spec to to_backend"
Support the entire-graph delegation flow through `EdgeProgramManager.to_backend`.

### Motivation

A current use case for backend lowering is the `to_backend(backend_id, exported_program, compile_spec)` API, which lowers the entire exported program to the specified `backend_id`. However, lowering via the `EdgeProgramManager` only allows partitioner-based lowering. The `EdgeProgramManager` is the main component that enables support for multiple methods; as a result, backends that rely on the old `to_backend(backend_id, ...)` API cannot export ExecuTorch models with multiple methods.

### Design

We extend `EdgeProgramManager.to_backend` to also accept a `DelegationSpec` in place of a `Partitioner`. A `DelegationSpec` is essentially a wrapper around the `backend_id` and the compile specs, so anywhere a partitioner can be supplied to lower a graph, a delegation spec can be supplied instead to perform entire-graph lowering.

### Intended Flow

```
del_spec = DelegationSpec("BackendWithCompilerDemo", [CompileSpec(...)])

encode_graph = torch.export.export(Encoder(), sample_inputs)
decode_graph = torch.export.export(Decoder(), sample_inputs)

edge_manager = to_edge({
    "encode": encode_graph,
    "decode": decode_graph,
})

lowered_edge_manager = edge_manager.to_backend(del_spec)

# or, to specify which methods to lower with del_spec:
lowered_edge_manager = edge_manager.to_backend({
    "encode": del_spec,
})
```

Differential Revision: [D69086565](https://our.internmc.facebook.com/intern/diff/D69086565/)

cc cccclai

[ghstack-poisoned]
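The design described above hinges on `DelegationSpec` being a thin wrapper around a backend identifier and its compile specs, so it can stand in wherever a partitioner is expected. A minimal, self-contained sketch of that shape (field and class names here mirror the commit message; the actual ExecuTorch definitions under `exir.backend` may differ):

```python
# Hypothetical sketch of the DelegationSpec wrapper; not the actual
# ExecuTorch implementation.
from dataclasses import dataclass
from typing import List


@dataclass
class CompileSpec:
    # A key/value pair forwarded to the backend's ahead-of-time compiler.
    key: str
    value: bytes


@dataclass
class DelegationSpec:
    # Wraps the backend identifier and compile specs so entire-graph
    # lowering needs no Partitioner.
    backend_id: str
    compile_specs: List[CompileSpec]


spec = DelegationSpec("BackendWithCompilerDemo", [CompileSpec("max_value", b"4")])
print(spec.backend_id)          # BackendWithCompilerDemo
print(len(spec.compile_specs))  # 1
```

With this shape, `to_backend` can dispatch on the argument type: a `Partitioner` triggers the existing partitioner-based flow, while a `DelegationSpec` (or a method-name-to-spec dict) lowers each whole graph to `backend_id`.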
2 parents 4127895 + 9893a11 commit 5f76e91

File tree

1 file changed: +2 −3 lines


exir/backend/test/test_backends.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -1278,7 +1278,6 @@ def __init__(self):

            def forward(self, x):
                return [torch.sin(x)]
-

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
@@ -1353,7 +1352,7 @@ def __init__(self):

            def forward(self, x):
                return torch.sin(x)
-
+
            def inputs(self):
                return (torch.ones(1),)

@@ -1365,7 +1364,7 @@ def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z
-
+
            def inputs(self):
                return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))

```