Skip to content

Commit 698d431

Browse files
committed
[ExecuTorch][Weight Sharing] Track Named Data Store in EdgeProgramManager
We enable backends to return named data by adding a NamedDataStoreOutput to the preprocess result. This is a fully backward-compatible (BC) change: backends with an existing preprocess implementation see no difference unless they explicitly opt in. To use the new NamedDataStore, a backend developer can create a NamedDataStore() inside preprocess, call add_named_data on it, and return NamedDataStore.get_named_data_store_output() in the preprocess result, like so:

```
def preprocess(exported_program: ExportedProgram, compile_specs: List[CompileSpec]) -> PreprocessResult:
    named_data_store = NamedDataStore()
    for node in exported_program.graph.nodes:
        named_data_store.add_named_data("name", bytes)
    return PreprocessResult(
        processed_bytes=bytes,
        debug_handle_map={},
        data_store_output=named_data_store.get_named_data_store_output(),
    )
```

Under the hood, the data store output is embedded in the LoweredBackendModule (serializing a LoweredBackendModule by itself with a named_data_store_output is still a TODO). Via the EdgeProgramManager path, however, we add each named_data_store_output to the EdgeProgramManager's named data store, so that it keeps track of all the named data returned by backends. Differential Revision: [D70451660](https://our.internmc.facebook.com/intern/diff/D70451660/) ghstack-source-id: 271070690 Pull Request resolved: #9151
1 parent cf8ce89 commit 698d431

File tree

7 files changed

+302
-4
lines changed

7 files changed

+302
-4
lines changed

exir/backend/backend_api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
import logging
1010
from contextlib import contextmanager, nullcontext
1111
from functools import singledispatch
12-
from typing import Generator, List
12+
from typing import Generator, List, Optional
1313

1414
import torch
1515

16-
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
16+
from executorch.exir.backend.backend_details import (
17+
BackendDetails,
18+
PreprocessResult,
19+
)
20+
from executorch.exir._serialize._named_data_store import NamedDataStore
1721
from executorch.exir.backend.compile_spec_schema import CompileSpec
1822

1923
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
@@ -120,6 +124,7 @@ def to_backend(
120124
backend_id=backend_id,
121125
processed_bytes=preprocess_result.processed_bytes,
122126
compile_specs=compile_specs,
127+
named_data_store_output=preprocess_result.data_store_output
123128
)
124129
lowered_module.meta = {
125130
"debug_handle_map": preprocess_result.debug_handle_map

exir/backend/backend_details.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from executorch.exir.backend.compile_spec_schema import CompileSpec
1313
from torch.export.exported_program import ExportedProgram
14+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1415

1516

1617
def enforcedmethod(func):
@@ -24,6 +25,11 @@ class PreprocessResult:
2425
debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
2526
None
2627
)
28+
# Data Store output created from NamedDataStore.
29+
30+
# Named Data store contains all the named data that is stored in the PTE file,
31+
# but retrievable by delegates via the NamedDataMap at runtime.
32+
data_store_output: Optional[NamedDataStoreOutput] = None
2733

2834

2935
"""

exir/backend/test/TARGETS

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,62 @@ python_library(
3838
],
3939
)
4040

41+
python_library(
42+
name = "backend_with_named_data_map",
43+
srcs = [
44+
"backend_with_named_data_map.py",
45+
],
46+
visibility = [
47+
"//executorch/...",
48+
"//executorch/test/...",
49+
],
50+
deps = [
51+
"//caffe2:torch",
52+
"//caffe2/functorch:functorch_src",
53+
"//executorch/exir:delegate",
54+
"//executorch/exir:graph_module",
55+
"//executorch/exir:lib",
56+
"//executorch/exir:lowered_backend_module",
57+
"//executorch/exir:print_program",
58+
"//executorch/exir:schema",
59+
"//executorch/exir/backend:backend_api",
60+
"//executorch/exir/backend:compile_spec_schema",
61+
"//executorch/exir/backend:partitioner",
62+
"//executorch/exir/dialects:lib",
63+
"//executorch/extension/pybindings:portable_lib", # @manual
64+
"//executorch/extension/pytree:pylib",
65+
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
66+
],
67+
)
68+
69+
python_unittest(
70+
name = "test_backend_with_named_data_map",
71+
srcs = [
72+
"test_backend_with_named_data_map.py",
73+
],
74+
visibility = [
75+
"//executorch/...",
76+
"//executorch/test/...",
77+
],
78+
deps = [
79+
"//caffe2:torch",
80+
"//caffe2/functorch:functorch_src",
81+
"//executorch/exir:delegate",
82+
"//executorch/exir:graph_module",
83+
"//executorch/exir:lib",
84+
"//executorch/exir:lowered_backend_module",
85+
"//executorch/exir:print_program",
86+
"//executorch/exir:schema",
87+
"//executorch/exir/backend:backend_api",
88+
"//executorch/exir/backend:compile_spec_schema",
89+
"//executorch/exir/backend:partitioner",
90+
"//executorch/exir/dialects:lib",
91+
"//executorch/extension/pybindings:portable_lib", # @manual
92+
"//executorch/extension/pytree:pylib",
93+
":backend_with_named_data_map",
94+
],
95+
)
96+
4197
python_library(
4298
name = "qnn_backend_demo",
4399
srcs = [
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import final, List, Dict, Tuple
8+
9+
import torch
10+
11+
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
12+
from executorch.exir.backend.compile_spec_schema import CompileSpec
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from torch.export.exported_program import ExportedProgram
15+
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
16+
generate_pattern_op_partitions,
17+
)
18+
19+
from executorch.exir.backend.compile_spec_schema import CompileSpec
20+
from executorch.exir.backend.partitioner import (
21+
DelegationSpec,
22+
Partitioner,
23+
PartitionResult,
24+
)
25+
from executorch.exir.dialects._ops import ops as exir_ops
26+
from executorch.exir.graph_module import get_control_flow_submodules
27+
from torch.export import ExportedProgram
28+
from torch.fx.passes.operator_support import OperatorSupportBase
29+
from executorch.exir._serialize._named_data_store import NamedDataStore
30+
31+
32+
# Backend details are final (cannot be subclassed).
@final
class BackendWithNamedDataMap(BackendDetails):
    """
    Test backend that exercises the named data store.

    ``preprocess`` produces empty ``processed_bytes``. For each supported
    ``call_function`` node it registers one named-data entry keyed by
    ``node.target.__name__`` instead.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Register one named-data entry per supported op found in the graph.

        Args:
            edge_program: Program whose graph nodes are scanned.
            compile_specs: Part of the preprocess signature; not read here.

        Returns:
            A PreprocessResult with empty ``processed_bytes``, an empty
            ``debug_handle_map``, and the store's output as
            ``data_store_output``.
        """
        code_for_target = {
            exir_ops.edge.aten.sin.default: 0,
            exir_ops.edge.aten.add.Tensor: 1,
            exir_ops.edge.aten.sub.Tensor: 2,
            exir_ops.edge.aten.mul.Tensor: 3,
            exir_ops.edge.aten.div.Tensor: 4,
        }
        store = NamedDataStore()
        for node in edge_program.graph.nodes:
            # Only call_function nodes with a known target contribute data.
            if node.op != "call_function" or node.target not in code_for_target:
                continue
            # bytes(<int>) builds a zero-filled byte string of that length, so
            # the payload for code k is k zero bytes (empty for code 0).
            payload = bytes(code_for_target[node.target])
            store.add_named_data(node.target.__name__, payload)

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=store.get_named_data_store_output(),
        )
67+
68+
class SimpleOperatorSupport(OperatorSupportBase):
    """Operator support check accepting a fixed list of edge-dialect ops."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """
        Return True when ``node`` is a ``call_function`` node whose target is
        one of the edge sin/add/sub/mul/div overloads listed below.

        ``submodules`` is part of the interface and is not consulted here.
        """
        if node.op != "call_function":
            return False
        supported_targets = [
            exir_ops.edge.aten.sin.default,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.div.Tensor,
        ]
        return node.target in supported_targets
77+
78+
@final
class BackendWithNDMPartitioner(Partitioner):
    """
    Partitioner that tags nodes accepted by ``SimpleOperatorSupport`` for
    delegation to ``BackendWithNamedDataMap``.
    """

    def __init__(self) -> None:
        self._op_support = SimpleOperatorSupport()
        self.backend_id = BackendWithNamedDataMap.__name__

    def _partition_gm(
        self, graph_module: torch.fx.GraphModule, id_start: int = 0
    ) -> Tuple[int, Dict[str, DelegationSpec]]:
        """
        Tag the partitions of ``graph_module`` and, recursively, of its
        control-flow submodules.

        Args:
            graph_module: Graph module to partition; the ``meta`` of each
                partitioned node gets a ``"delegation_tag"`` entry.
            id_start: Offset added to every partition id when forming the
                ``tag_<n>`` name.

        Returns:
            A tuple ``(next_id, partition_tags)``: ``next_id`` is the offset
            to pass to the next graph module, and ``partition_tags`` maps each
            tag created here or in a submodule to its DelegationSpec.
        """
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            graph_module, op_support=self._op_support
        )

        for partition in partition_list:
            curr_par_id = partition.id or 0
            delegation_tag = f"tag_{curr_par_id + id_start}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
            partition_tags[delegation_tag] = DelegationSpec(self.backend_id, [])

        # Carry the incoming offset forward. Restarting from the local
        # partition count alone would let sibling or nested submodules reuse
        # tag indices that were already handed out.
        # NOTE(review): assumes partition ids fall in
        # range(len(partition_list)) -- confirm against the partitioner.
        next_id = id_start + len(partition_list)
        for _, submodule, _ in get_control_flow_submodules(graph_module):
            next_id, submodule_tags = self._partition_gm(submodule, next_id)
            partition_tags.update(submodule_tags)

        return next_id, partition_tags

    def partition(self, edge_program: ExportedProgram) -> PartitionResult:
        """
        Tag ``edge_program`` in place and return it together with the
        tag-to-DelegationSpec mapping.
        """
        _, partition_tags = self._partition_gm(edge_program.graph_module)
        return PartitionResult(
            tagged_exported_program=edge_program,
            partition_tags=partition_tags,
        )
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from executorch.exir import to_edge
12+
13+
from executorch.exir.backend.test.backend_with_named_data_map import (
14+
BackendWithNamedDataMap,
15+
BackendWithNDMPartitioner
16+
)
17+
from executorch.exir.backend.backend_api import to_backend
18+
19+
from torch.testing import FileCheck
20+
from torch.export.exported_program import ExportedProgram
21+
22+
class TestBackendWithNamedDataMap(unittest.TestCase):
    """
    Checks that named data produced by ``BackendWithNamedDataMap.preprocess``
    is available after lowering, both on a directly lowered module and on the
    edge program manager's named data store.
    """

    def test_lowered_backend_module_has_output(self):
        """Lowering a single add exposes one named-data entry on the module."""

        class M(torch.nn.Module):
            def forward(self, x):
                return x + x

        ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
        lowered = to_backend(
            BackendWithNamedDataMap.__name__, ep.exported_program(), []
        )

        buffer_entries = lowered.named_data_store_output.buffers
        self.assertEqual(len(buffer_entries), 1)
        stored_data = lowered.named_data_store_output.pte_data

        self.assertIn("aten.add.Tensor", stored_data)
        # bytes(1) is a single zero byte.
        self.assertEqual(buffer_entries[0].buffer, bytes(1))

    def test_named_data_with_partitioner(self):
        """Partitioned lowering records add, sub and sin in the manager's store."""

        class M(torch.nn.Module):
            def forward(self, x):
                y = x + x
                y = torch.cos(y)
                y = y + y
                y = torch.sin(y)
                return y - y

        ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
        ep.to_backend(BackendWithNDMPartitioner())

        ndm_output = ep._named_data_store.get_named_data_store_output()
        buffer_entries = ndm_output.buffers
        stored_data = ndm_output.pte_data
        self.assertEqual(len(buffer_entries), 3)
        self.assertIn("aten.add.Tensor", stored_data)
        self.assertIn("aten.sub.Tensor", stored_data)
        self.assertIn("aten.sin.default", stored_data)

    def test_named_data_with_control_flow(self):
        """Ops inside ``torch.cond`` branches also contribute named data."""

        class M(torch.nn.Module):
            def true_branch(self, x):
                y = x * x
                y = torch.cos(y)
                return torch.sin(y)

            def false_branch(self, x):
                return torch.sin(x)

            def forward(self, x, y):
                z = x / y
                z = torch.cond(z > 1, self.true_branch, self.false_branch, [x])
                return z - z

        ep = to_edge(torch.export.export(M(), (torch.randn(1, 2), torch.randn(1, 2))))
        ep.to_backend(BackendWithNDMPartitioner())

        ndm_output = ep._named_data_store.get_named_data_store_output()
        buffer_entries = ndm_output.buffers
        stored_data = ndm_output.pte_data
        self.assertEqual(len(buffer_entries), 4)
        self.assertIn("aten.sub.Tensor", stored_data)
        self.assertIn("aten.div.Tensor", stored_data)
        self.assertIn("aten.sin.default", stored_data)
        self.assertIn("aten.mul.Tensor", stored_data)

exir/lowered_backend_module.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
2424
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
2525
from executorch.exir.schema import Program
26+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
2627

2728
from executorch.exir.tracer import Value
2829
from torch._library.fake_class_registry import FakeScriptObject
@@ -62,19 +63,22 @@ class LoweredBackendModule(torch.nn.Module):
6263
CompileSpec
6364
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
6465
_original_exported_program: ExportedProgram # The original EXIR module
66+
_named_data_store_output: Optional[NamedDataStoreOutput] # Named Data serialized by the backend
6567

6668
def __init__(
6769
self,
6870
edge_program: ExportedProgram,
6971
backend_id: str,
7072
processed_bytes: bytes,
7173
compile_specs: List[CompileSpec],
74+
named_data_store_output: Optional[NamedDataStoreOutput] = None,
7275
) -> None:
7376
super().__init__()
7477
self._original_exported_program = edge_program
7578
self._backend_id = backend_id
7679
self._processed_bytes = processed_bytes
7780
self._compile_specs = compile_specs
81+
self._named_data_store_output = named_data_store_output
7882

7983
# pyre-ignore
8084
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
@@ -133,6 +137,13 @@ def original_module(self) -> ExportedProgram:
133137
Returns the original EXIR module
134138
"""
135139
return self._original_exported_program
140+
141+
@property
142+
def named_data_store_output(self) -> Optional[NamedDataStoreOutput]:
143+
"""
144+
Returns the Named Data Store Output
145+
"""
146+
return self._named_data_store_output
136147

137148
# TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
138149
def buffer(
@@ -154,6 +165,7 @@ def buffer(
154165
segment_alignment=segment_alignment,
155166
constant_tensor_alignment=constant_tensor_alignment,
156167
delegate_alignment=delegate_alignment,
168+
named_data=self.named_data_store_output,
157169
)
158170
)
159171
return out

0 commit comments

Comments
 (0)