Commit e7a42e1

Update on "[executorch][runtime] Introduce PteDataMap for weight sharing"
PteDataMap is the NamedDataMap that will live in the runtime. It is used to give delegates access to opaque named data stored in the PTE file. Open to alternative naming suggestions, maybe 'PTEDataMap' or 'ProgramDataMap'?

**Usage**

The PteDataMap is owned by the Program and instantiated at program load time if named_data exists in the PTE file.

We introduce usage of 'std::optional' here. We could also use executorch::aten::optional to avoid pulling in the standard library?

When delegates are initialized, the PteDataMap is given to delegate_init. Delegates can retrieve opaque delegate data by key using 'get_data', which returns a FreeableBuffer that they can free later.

**Testing**

The test uses the C++ flatbuffer API to build a fake program containing named data. We also create a temp file with sample data that the data loader can wrap.

TODO: add an e2e test once delegate AOT is ready and we can generate a PTE file with named data.

**Note**

As the PteDataMap wraps flatbuffer constructs, the Program must outlive the PteDataMap.

PteDataMap does not implement:
- get_metadata: currently, all stored data is opaque. We can implement get_metadata later if a backend stores plain tensor data.
- load_into: this is mostly used for the training case and isn't used by delegates, at least not at the moment.

Differential Revision: [D70213646](https://our.internmc.facebook.com/intern/diff/D70213646/)

[ghstack-poisoned]
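To make the delegate-side flow concrete, here is a minimal C++ sketch of what retrieving named data could look like during delegate init. It assumes only what the summary above states: the delegate receives the program-owned map through the NamedDataMap interface, and `get_data(key)` returns a `Result<FreeableBuffer>`. The function name `init_my_delegate`, the key `"my_blob"`, and the exact way the map pointer reaches the delegate are illustrative assumptions, not the actual backend API.

```cpp
// Sketch only, not the real delegate_init signature. Assumes the Program
// (which owns the PteDataMap) outlives this map, as noted above.
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/freeable_buffer.h>
#include <executorch/runtime/core/named_data_map.h>
#include <executorch/runtime/core/result.h>

using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;

Error init_my_delegate(const NamedDataMap* data_map) {
  // Look up the opaque blob that the AOT flow stored under a known key.
  // ("my_blob" is a hypothetical key for illustration.)
  Result<FreeableBuffer> blob = data_map->get_data("my_blob");
  if (!blob.ok()) {
    return blob.error();
  }
  // Use the data; the buffer stays valid until Free() is called.
  const void* data = blob->data();
  size_t size = blob->size();
  (void)data;
  (void)size;
  // Release the buffer once the delegate no longer needs it.
  blob->Free();
  return Error::Ok;
}
```

Handing back a FreeableBuffer rather than a raw pointer lets each delegate decide when its copy of the data is released, which is the hook that makes sharing one named blob across delegates practical.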
Merge commit e7a42e1 (2 parents: fea62dd + 2827dfb)


59 files changed: +402 −216 lines

CMakeLists.txt

Lines changed: 5 additions & 4 deletions
```diff
@@ -248,14 +248,15 @@ cmake_dependent_option(
   "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
 )
 
-if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
+if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
   set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
+  set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
+  set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
+  set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
 endif()
 
-if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
-  set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
+if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
   set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
-  set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
 endif()
 
 if(EXECUTORCH_BUILD_EXTENSION_MODULE)
```
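These blocks are evaluated in order, so the reordering matters: EXECUTORCH_BUILD_EXTENSION_TRAINING now switches on EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR (among others), and the FLAT_TENSOR block that follows then expands that into EXECUTORCH_BUILD_EXTENSION_DATA_LOADER.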

backends/apple/coreml/TARGETS

Lines changed: 8 additions & 10 deletions
```diff
@@ -14,10 +14,10 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "fbsource//third-party/pypi/coremltools:coremltools",
         ":executorchcoreml",
         "//executorch/exir/backend:backend_details",
         "//executorch/exir/backend:compile_spec_schema",
-        "fbsource//third-party/pypi/coremltools:coremltools",
     ],
 )
 
@@ -30,13 +30,13 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "fbsource//third-party/pypi/coremltools:coremltools",
         ":backend",
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
         "//executorch/exir/backend:partitioner",
         "//executorch/exir/backend:utils",
-        "fbsource//third-party/pypi/coremltools:coremltools",
     ],
 )
 
@@ -64,25 +64,23 @@ runtime.cxx_python_extension(
     headers = glob([
         "runtime/inmemoryfs/**/*.hpp",
     ]),
+    base_module = "",
+    compiler_flags = [
+        "-std=c++17",
+    ],
     preprocessor_flags = [
         "-Iexecutorch/backends/apple/coreml/runtime/util",
     ],
     types = [
         "executorchcoreml.pyi",
     ],
-    compiler_flags = [
-        "-std=c++17",
-    ],
-    base_module = "",
     visibility = [
         "//executorch/examples/apple/coreml/...",
         "@EXECUTORCH_CLIENTS",
     ],
-    external_deps = [
-        "pybind11",
-    ],
     deps = [
         "fbsource//third-party/nlohmann-json:nlohmann-json",
+        "fbsource//third-party/pybind11:pybind11",
     ],
 )
 
@@ -92,10 +90,10 @@ runtime.python_test(
         "test/*.py",
     ]),
     deps = [
+        "fbsource//third-party/pypi/pytest:pytest",
         ":partitioner",
         ":quantizer",
         "//caffe2:torch",
         "//pytorch/vision:torchvision",
-        "fbsource//third-party/pypi/pytest:pytest",
     ],
 )
```

backends/arm/tosa_mapping.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -107,7 +107,10 @@ def __init__(self, argument: Any) -> None:
         if isinstance(argument, (int, float)):
             self.__process_number(argument)
             return
+        if isinstance(argument, torch.dtype):
+            # Dtype is parsed from fake tensor
+            return
 
-        RuntimeError(
+        raise RuntimeError(
             f"Unhandled node input argument: {argument}, of type {type(argument)}"
         )
```
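Two fixes in one hunk: torch.dtype arguments (parsed from fake tensors) are now accepted, and the RuntimeError that was previously constructed but never raised is now actually raised, so unhandled argument types fail loudly instead of being silently ignored.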

backends/qualcomm/aot/python/targets.bzl

Lines changed: 3 additions & 5 deletions
```diff
@@ -33,10 +33,10 @@ def define_common_targets():
             "//executorch/backends/qualcomm:schema",
             "//executorch/backends/qualcomm/aot/ir:qcir_utils",
             "//executorch/backends/qualcomm/runtime:runtime",
+            "fbsource//third-party/pybind11:pybind11",
             "fbsource//third-party/qualcomm/qnn/qnn-{0}:api".format(get_qnn_library_verision()),
         ],
         external_deps = [
-            "pybind11",
             "libtorch_python",
         ],
         use_static_deps = True,
@@ -66,10 +66,10 @@ def define_common_targets():
             "//executorch/backends/qualcomm:schema",
             "//executorch/backends/qualcomm/aot/ir:qcir_utils",
             "//executorch/backends/qualcomm/runtime:runtime",
+            "fbsource//third-party/pybind11:pybind11",
             "fbsource//third-party/qualcomm/qnn/qnn-{0}:api".format(get_qnn_library_verision()),
         ],
         external_deps = [
-            "pybind11",
             "libtorch_python",
         ],
         use_static_deps = True,
@@ -93,9 +93,7 @@ def define_common_targets():
             "//executorch/backends/qualcomm:schema",
             "//executorch/backends/qualcomm/aot/ir:qcir_utils",
             "//executorch/backends/qualcomm/runtime:runtime",
+            "fbsource//third-party/pybind11:pybind11",
             "fbsource//third-party/qualcomm/qnn/qnn-{0}:api".format(get_qnn_library_verision()),
         ],
-        external_deps = [
-            "pybind11",
-        ],
     )
```

docs/source/using-executorch-building-from-source.md

Lines changed: 8 additions & 0 deletions
````diff
@@ -80,6 +80,14 @@ portability details.
 ./install_executorch.sh --pybind off
 ```
 
+For development, install the package in `--editable` mode, which allows you to modify the Python source code and see changes reflected immediately.
+```
+./install_executorch.sh --editable [--pybind xnnpack]
+
+# Or you can directly do the following if dependencies are already installed.
+pip install -e .
+```
+
 > **_NOTE:_** Cleaning the build system
 >
 > When fetching a new version of the upstream repo (via `git fetch` or `git
````

examples/models/checkpoint.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -64,7 +64,7 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
     mismatched_dtypes = [
         (key, value.dtype)
         for key, value in checkpoint.items()
-        if value.dtype != dtype
+        if hasattr(value, "dtype") and value.dtype != dtype
     ]
     if len(mismatched_dtypes) > 0:
         print(
```

examples/models/llama/runner/generation.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import time
 from abc import ABC, abstractmethod
 from typing import List, Optional
 
@@ -97,6 +98,7 @@ def generate(  # noqa: C901
         pos_base: int = 0,
     ) -> List[int]:
         # Prefill
+        prefill_start = time.time()
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
@@ -105,11 +107,13 @@ def generate(  # noqa: C901
                 else None
             ),
         )
+        prefill_time = time.time() - prefill_start
 
         current_token = next_token(logits, temperature, top_p)
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
+        generate_start = time.time()
         while len(tokens) < max_seq_len:
             if self.use_kv_cache:
                 logits = self.forward(
@@ -140,6 +144,10 @@ def generate(  # noqa: C901
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         print("\n")
 
+        generate_time = time.time() - generate_start
+        print(f"Prefill time: {prefill_time}")
+        print(f"Generation tok/s: {len(tokens) / generate_time}")
+
         return tokens if echo else tokens[len(prompt_tokens) :]
 
     def text_completion(
```
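The new timers split performance into its two phases: prefill_time covers the single batched forward over the prompt, while generate_time covers the token-by-token decode loop. Note that `len(tokens)` includes the prompt tokens, so the reported tok/s appears slightly inflated for long prompts.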

exir/_serialize/_named_data_store.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -181,3 +181,30 @@ def get_named_data_store_output(self) -> NamedDataStoreOutput:
         # Clean up empty maps inside self.external_data
         self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
         return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)
+
+    def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
+        """
+        Merge another NamedDataStore into this one.
+        Args:
+            other (NamedDataStore): the other NamedDataStore to merge.
+        Raises:
+            ValueError: when the key exists in both stores, and corresponding
+                data is different between them.
+        """
+        # Merge the pte_data.
+        for key, buffer_idx in other.pte_data.items():
+            self.add_named_data(
+                key,
+                other.buffers[buffer_idx].buffer,
+                other.buffers[buffer_idx].alignment,
+            )
+
+        # Merge the external_data.
+        for filename, key_to_buffer_idx in other.external_data.items():
+            for key, buffer_idx in key_to_buffer_idx.items():
+                self.add_named_data(
+                    key,
+                    other.buffers[buffer_idx].buffer,
+                    other.buffers[buffer_idx].alignment,
+                    external_tag=filename,
+                )
```
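Because the merge funnels through add_named_data, identical (key, data) pairs deduplicate to a single buffer, while the same key with different data raises ValueError; the tests below exercise both paths.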

exir/_serialize/test/test_named_data_store.py

Lines changed: 59 additions & 0 deletions
```diff
@@ -83,3 +83,62 @@ def test_add_duplicate_key_fail(self) -> None:
         self.assertEqual(len(output.pte_data), 1)
         self.assertEqual(output.pte_data["key"], 0)
         self.assertEqual(len(output.external_data), 0)
+
+    def test_merge(self) -> None:
+        store1 = NamedDataStore()
+        store1.add_named_data("key1", b"data1", None, None)
+        store1.add_named_data("key2", b"data2", 16, "file1")
+
+        # Check items in store1.
+        output = store1.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 2)
+        self.assertEqual(len(output.pte_data), 1)
+        self.assertEqual(len(output.external_data), 1)
+        self.assertEqual(len(output.external_data["file1"]), 1)
+
+        store2 = NamedDataStore()
+        store2.add_named_data("key1", b"data1", None, None)
+        store2.add_named_data("key3", b"data3", None, None)
+        store2.add_named_data("key4", b"data4", 16, "file1")
+        store2.add_named_data("key5", b"data5", 16, "file2")
+
+        # Check items in store2.
+        output2 = store2.get_named_data_store_output()
+        self.assertEqual(len(output2.buffers), 4)
+        self.assertEqual(len(output2.pte_data), 2)
+        self.assertEqual(len(output2.external_data), 2)
+        self.assertEqual(len(output2.external_data["file1"]), 1)
+        self.assertEqual(len(output2.external_data["file2"]), 1)
+
+        # Merge store2 into store1.
+        store1.merge_named_data_store(output2)
+
+        # Check that items in store2 are merged into store1.
+        output = store1.get_named_data_store_output()
+        # key1, data1 exist in both store1 and store2, so we only keep one copy.
+        self.assertEqual(len(output.buffers), 5)
+        self.assertEqual(len(output.pte_data), 2)
+        self.assertEqual(len(output.external_data), 2)
+        self.assertEqual(len(output.external_data["file1"]), 2)
+        self.assertEqual(len(output.external_data["file2"]), 1)
+
+    def test_merge_duplicate_error(self) -> None:
+        store1 = NamedDataStore()
+        store1.add_named_data("key1", b"data1", None, None)
+
+        # Check items in store1.
+        output = store1.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 1)
+        self.assertEqual(len(output.pte_data), 1)
+
+        store2 = NamedDataStore()
+        store2.add_named_data("key1", b"data2", None, None)
+
+        # Check items in store2.
+        output2 = store2.get_named_data_store_output()
+        self.assertEqual(len(output2.buffers), 1)
+        self.assertEqual(len(output2.pte_data), 1)
+
+        # Merging store2 into store1 raises an error, as key1 is already in
+        # store1 with different data.
+        self.assertRaises(ValueError, store1.merge_named_data_store, output2)
```

exir/backend/backend_api.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -56,9 +57,9 @@ def to_backend(
 ) -> LoweredBackendModule:
 
     def to_backend(
-        graph_module: torch.fx.GraphModule,
-        partitioner: Type[TPartitioner],
-    ) -> torch.fx.GraphModule
+        edge_program: ExportedProgram,
+        partitioner: Partitioner,
+    ) -> ExportedProgram:
     """
     pass
```
6465
