Commit a43894c
Reuse types in _named_data_store and support tensor layouts
Pull Request resolved: #14667

Reuse `DataEntry` from data_serializer.py in _named_data_store.py.

Motivation:
- deserialize from flat tensor to named data store output
- support tensor layout in the named data store

ghstack-source-id: 319685093
Differential Revision: D83490345 (https://our.internmc.facebook.com/intern/diff/D83490345/)

Parent: 09408e3
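For a concrete picture of the new surface, here is a minimal usage sketch assembled from the diffs and tests below (the key name and bytes are made up for illustration):

from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._serialize.data_serializer import DataEntry
from executorch.exir.scalar_type import ScalarType
from executorch.exir.tensor_layout import TensorLayout

store = NamedDataStore()
layout = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
# add_named_data now accepts an optional TensorLayout as its last argument.
store.add_named_data("weight", b"\x00" * 8, 16, None, layout)

output = store.get_named_data_store_output()
# pte_data values are DataEntry objects rather than bare buffer indices.
assert output.pte_data["weight"] == DataEntry(
    buffer_index=0, alignment=16, tensor_layout=layout
)
# buffers holds plain bytes; alignment and layout live on the entry.
assert output.buffers[0] == b"\x00" * 8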

File tree: 7 files changed, +142 -122 lines

exir/_serialize/_named_data_store.py (52 additions, 52 deletions)
@@ -7,42 +7,32 @@
 # pyre-strict
 
 import hashlib
-import math
 from dataclasses import dataclass
 
 # from dataclasses import dataclass
 from typing import Dict, List, Optional
 
-
-@dataclass
-class BufferEntry:
-    """A class to hold the buffer entries for serialization.
-
-    Attributes:
-        buffer: The buffer bytes.
-        alignment: The alignment of the buffer.
-    """
-
-    buffer: bytes
-    alignment: int
+from executorch.exir._serialize.data_serializer import DataEntry
+from executorch.exir.tensor_layout import TensorLayout
 
 
 @dataclass
 class NamedDataStoreOutput:
     """
-    Holds named data for serialization.
+    Holds named data for serialization. Note: a DataEntry contains the index into
+    `buffers`, the alignment and a tensor layout, if applicable.
 
     Attributes:
         buffers: A list of unique buffer entries.
         pte_data: Contains data that is stored inside the PTE file. A mapping from
-            {key: buffer_index}.
+            {key: DataEntry}.
         external_data: Contains data that is stored external to the PTE. A mapping
-            from {filename: {key: buffer_index}}.
+            from {filename: {key: DataEntry}}.
     """
 
-    buffers: List[BufferEntry]
-    pte_data: Dict[str, int]
-    external_data: Dict[str, Dict[str, int]]
+    buffers: List[bytes]
+    pte_data: Dict[str, DataEntry]
+    external_data: Dict[str, Dict[str, DataEntry]]
 
 
 class NamedDataStore:
@@ -61,12 +51,12 @@ class NamedDataStore:
     """
 
     # List of unique blobs.
-    buffers: List[BufferEntry]
-    # Named data stored inside the PTE file. Map of {key: buffer_index}.
-    pte_data: Dict[str, int]
+    buffers: List[bytes]
+    # Named data stored inside the PTE file. Map of {key: DataEntry}.
+    pte_data: Dict[str, DataEntry]
     # Named data stored outside of the PTE file.
-    # Map of {filename: {key: buffer_index}}.
-    external_data: Dict[str, Dict[str, int]]
+    # Map of {filename: {key: DataEntry}}.
+    external_data: Dict[str, Dict[str, DataEntry]]
 
     # Cache of the data hash for deduplication.
     # Use a hash instead of the data as a key because a sha256 collision is
@@ -93,7 +83,8 @@ def _add_named_data_to_map(
         key: str,
         data: bytes,
        alignment: int,
-        local_key_to_buffer_idx: Dict[str, int],
+        local_key_to_buffer_idx: Dict[str, DataEntry],
+        tensor_layout: Optional[TensorLayout] = None,
     ) -> None:
        """
        Add data to a map and update the alignment. Ensure that the key-data
@@ -116,33 +107,31 @@ def _add_named_data_to_map(
 
         # Check if the key exists.
         buffer_idx = self.key_to_buffer_idx.get(key, -1)
-        if buffer_idx != -1:
-            # If the key exists, the corresponding data must be identical.
-            if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
-                raise ValueError(
-                    f"Duplicate key {key} with different data. "
-                    f"Existing data: {self.buffers[buffer_idx].buffer}. "
-                    f"New data: {data}."
-                )
-            self.buffers[buffer_idx].alignment = math.lcm(
-                self.buffers[buffer_idx].alignment, alignment
+        # If the key exists, the corresponding data must be identical.
+        if (
+            buffer_idx != -1
+            and self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx
+        ):
+            raise ValueError(
+                f"Duplicate key {key} with different data. "
+                f"Existing data: {self.buffers[buffer_idx]}. "
+                f"New data: {data}."
            )
        else:
            # Key doesn't exist; check if the data exists.
            buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
-            if buffer_idx != -1:
-                # The data exists; update the alignment.
-                self.buffers[buffer_idx].alignment = math.lcm(
-                    self.buffers[buffer_idx].alignment, alignment
-                )
-            else:
+            if buffer_idx == -1:
                # The data doesn't exist; add it to the data store.
                buffer_idx = len(self.buffers)
-                self.buffers.append(BufferEntry(data, alignment))
+                self.buffers.append(data)
                self.data_hash_to_buffer_idx[hashed] = buffer_idx
 
        # Add key to the map and the key cache.
-        local_key_to_buffer_idx[key] = buffer_idx
+        local_key_to_buffer_idx[key] = DataEntry(
+            buffer_index=buffer_idx,
+            alignment=alignment,
+            tensor_layout=tensor_layout,
+        )
        self.key_to_buffer_idx[key] = buffer_idx
 
    def add_named_data(
@@ -151,6 +140,7 @@ def add_named_data(
         data: bytes,
         alignment: Optional[int] = 1,
         external_tag: Optional[str] = None,
+        tensor_layout: Optional[TensorLayout] = None,
     ) -> None:
        """
        Adds a named blob to the NamedDataStore.
@@ -159,6 +149,7 @@ def add_named_data(
             data (bytes): Bytes being requested to be serialized.
             alignment (int): alignment for bytes to be serialized with.
             external (Optional[str]): the external filename that this data is saved to.
+            tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
        Raises:
            ValueError: when the key exists in the store, and corresponding data
                is different.
@@ -171,10 +162,16 @@
             raise ValueError(f"Alignment must be greater than 0, received {alignment}.")
 
        if external_tag is None:
-            self._add_named_data_to_map(key, data, alignment, self.pte_data)
+            self._add_named_data_to_map(
+                key, data, alignment, self.pte_data, tensor_layout
+            )
        else:
            self._add_named_data_to_map(
-                key, data, alignment, self.external_data.setdefault(external_tag, {})
+                key,
+                data,
+                alignment,
+                self.external_data.setdefault(external_tag, {}),
+                tensor_layout,
            )
 
    def get_named_data_store_output(self) -> NamedDataStoreOutput:
@@ -192,19 +189,22 @@ def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
             data is different between them.
         """
         # Merge the pte_data.
-        for key, buffer_idx in other.pte_data.items():
+        for key, data_entry in other.pte_data.items():
            self.add_named_data(
                key,
-                other.buffers[buffer_idx].buffer,
-                other.buffers[buffer_idx].alignment,
+                other.buffers[data_entry.buffer_index],
+                data_entry.alignment,
+                external_tag=None,
+                tensor_layout=data_entry.tensor_layout,
            )
 
        # Merge the external_data.
-        for filename, key_to_buffer_idx in other.external_data.items():
-            for key, buffer_idx in key_to_buffer_idx.items():
+        for filename, key_to_data_entry in other.external_data.items():
+            for key, data_entry in key_to_data_entry.items():
                self.add_named_data(
                    key,
-                    other.buffers[buffer_idx].buffer,
-                    other.buffers[buffer_idx].alignment,
+                    other.buffers[data_entry.buffer_index],
+                    data_entry.alignment,
                    external_tag=filename,
+                    tensor_layout=data_entry.tensor_layout,
                )
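A behavioral consequence of dropping BufferEntry, visible in the hunks above: alignment is no longer folded into the shared buffer via math.lcm; each key now keeps its own alignment on its DataEntry. A small sketch matching the updated test expectations:

from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._serialize.data_serializer import DataEntry

store = NamedDataStore()
store.add_named_data("key", b"data", 3, None)
store.add_named_data("key1", b"data", 4, None)

output = store.get_named_data_store_output()
# Identical bytes still dedupe to a single buffer...
assert len(output.buffers) == 1
assert output.buffers[0] == b"data"
# ...but each key records its own alignment; previously the shared
# BufferEntry carried lcm(3, 4) == 12 for both keys.
assert output.pte_data["key"] == DataEntry(0, 3, None)
assert output.pte_data["key1"] == DataEntry(0, 4, None)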

exir/_serialize/_program.py (11 additions, 12 deletions)
@@ -12,7 +12,7 @@
 import re
 
 from dataclasses import dataclass
-from typing import ClassVar, Dict, List, Literal, Optional, Tuple
+from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple
 
 from executorch.exir._serialize._cord import Cord
 from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -21,10 +21,9 @@
     _program_flatbuffer_to_json,
     _program_json_to_flatbuffer,
 )
-from executorch.exir._serialize._named_data_store import (
-    BufferEntry,
-    NamedDataStoreOutput,
-)
+from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
+
+from executorch.exir._serialize.data_serializer import DataEntry
 
 from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
 
@@ -368,8 +367,8 @@ def _extract_constant_segment(
 def _extract_named_data(
     program: Program,
     segments: List[AlignedData],
-    buffers: List[BufferEntry],
-    name_to_buffer_idx: Dict[str, int],
+    buffers: Sequence[bytes],
+    name_to_data_entry: Dict[str, DataEntry],
 ) -> None:
     """Modifies the program in-place to add references to the named data
     segments.
@@ -379,7 +378,7 @@ def _extract_named_data(
         segments: A list of buffers to append extracted segments to. Modified in-place.
         buffers: A list of unique buffers and the information required to
             serialize them. Not modified.
-        name_to_buffer_idx: A map from the name of a blob to the index in buffers.
+        name_to_data_entry: A map from the blob name to DataEntry.
            Not modified.
    """
    if program.named_data is not None and len(program.named_data) > 0:
@@ -389,14 +388,14 @@
         segment_index_map: Dict[int, int] = {}
 
         named_data: List[NamedData] = []
-        for name, buffer_idx in name_to_buffer_idx.items():
-            segment_index = segment_index_map.get(buffer_idx, None)
+        for name, data_entry in name_to_data_entry.items():
+            segment_index = segment_index_map.get(data_entry.buffer_index, None)
            if segment_index is None:
                segment_index = len(segments)
-                segment_index_map[buffer_idx] = segment_index
+                segment_index_map[data_entry.buffer_index] = segment_index
                segments.append(
                    AlignedData(
-                        Cord(buffers[buffer_idx].buffer), buffers[buffer_idx].alignment
+                        Cord(buffers[data_entry.buffer_index]), data_entry.alignment
                    )
                )
            named_data.append(NamedData(key=name, segment_index=segment_index))
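The rewritten loop dedups segments by buffer index: every name that references the same buffer shares one segment. A self-contained sketch of just that mapping, using a hypothetical dedup_segments helper in place of the real function's in-place mutation of program and segments:

from typing import Dict, List, Sequence, Tuple

from executorch.exir._serialize.data_serializer import DataEntry

def dedup_segments(
    name_to_data_entry: Dict[str, DataEntry], buffers: Sequence[bytes]
) -> Tuple[List[bytes], Dict[str, int]]:
    # One segment per unique buffer_index, however many names point at it.
    segment_index_map: Dict[int, int] = {}
    segments: List[bytes] = []
    name_to_segment: Dict[str, int] = {}
    for name, data_entry in name_to_data_entry.items():
        segment_index = segment_index_map.get(data_entry.buffer_index, None)
        if segment_index is None:
            # First time this buffer is seen: append it as a new segment.
            segment_index = len(segments)
            segment_index_map[data_entry.buffer_index] = segment_index
            segments.append(buffers[data_entry.buffer_index])
        name_to_segment[name] = segment_index
    return segments, name_to_segment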

exir/_serialize/_serialize.py (7 additions, 7 deletions)
@@ -117,18 +117,18 @@ def serialize_for_executorch(
             )
             buffers.append(emitter_output.external_constant_buffer[index])
 
-        # Extract external data.
+        # Extract external data from named_data_store.
         # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
-        key_to_buffer_index = named_data_store.external_data.get(tag, {})
-        for key, index in key_to_buffer_index.items():
+        blob_to_data_entry = named_data_store.external_data.get(tag, {})
+        for key, data_entry in blob_to_data_entry.items():
            assert key not in key_to_data_entry  # key must be unique
            key_to_data_entry[key] = DataEntry(
                buffer_index=len(buffers),
-                # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
-                alignment=named_data_store.buffers[index].alignment,
-                tensor_layout=None,
+                alignment=data_entry.alignment,
+                tensor_layout=data_entry.tensor_layout,
            )
-            buffers.append(named_data_store.buffers[index].buffer)
+            # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
+            buffers.append(named_data_store.buffers[data_entry.buffer_index])
 
        # Serialize into PTD file.
        ptd_files[tag] = data_serializer.serialize(
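The subtle point in this hunk is the re-indexing: data_entry.buffer_index points into the named data store's buffer list, while the DataEntry written to key_to_data_entry must point at the position the bytes take in the PTD file's own buffer list. A minimal sketch of that translation, with a hypothetical reindex_for_ptd helper standing in for the inline loop:

from typing import Dict, List, Tuple

from executorch.exir._serialize.data_serializer import DataEntry

def reindex_for_ptd(
    blob_to_data_entry: Dict[str, DataEntry], store_buffers: List[bytes]
) -> Tuple[Dict[str, DataEntry], List[bytes]]:
    key_to_data_entry: Dict[str, DataEntry] = {}
    buffers: List[bytes] = []
    for key, data_entry in blob_to_data_entry.items():
        # Fresh index into the PTD buffer list; alignment and tensor_layout
        # carry over from the store's entry unchanged.
        key_to_data_entry[key] = DataEntry(
            buffer_index=len(buffers),
            alignment=data_entry.alignment,
            tensor_layout=data_entry.tensor_layout,
        )
        buffers.append(store_buffers[data_entry.buffer_index])
    return key_to_data_entry, buffers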

exir/_serialize/test/test_named_data_store.py (36 additions, 17 deletions)
@@ -8,7 +8,10 @@
 
 import unittest
 
-from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore
+from executorch.exir._serialize._named_data_store import NamedDataStore
+from executorch.exir._serialize.data_serializer import DataEntry
+from executorch.exir.scalar_type import ScalarType
+from executorch.exir.tensor_layout import TensorLayout
 
 
 class TestNamedDataStore(unittest.TestCase):
@@ -21,17 +24,17 @@ def test_add(self) -> None:
         output = store.get_named_data_store_output()
 
         self.assertEqual(len(output.buffers), 3)
-        self.assertEqual(output.buffers[0], BufferEntry(b"data1", 1))
-        self.assertEqual(output.buffers[1], BufferEntry(b"data2", 16))
-        self.assertEqual(output.buffers[2], BufferEntry(b"data3", 16))
+        self.assertEqual(output.buffers[0], b"data1")
+        self.assertEqual(output.buffers[1], b"data2")
+        self.assertEqual(output.buffers[2], b"data3")
 
        self.assertEqual(len(output.pte_data), 1)
-        self.assertEqual(output.pte_data["key1"], 0)
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, None))
 
        self.assertEqual(len(output.external_data), 1)
        self.assertEqual(len(output.external_data["file1"]), 2)
-        self.assertEqual(output.external_data["file1"]["key2"], 1)
-        self.assertEqual(output.external_data["file1"]["key3"], 2)
+        self.assertEqual(output.external_data["file1"]["key2"], DataEntry(1, 16, None))
+        self.assertEqual(output.external_data["file1"]["key3"], DataEntry(2, 16, None))
 
    def test_add_duplicate_name_and_data(self) -> None:
        store = NamedDataStore()
@@ -41,10 +44,10 @@ def test_add_duplicate_name_and_data(self) -> None:
         output = store.get_named_data_store_output()
 
         self.assertEqual(len(output.buffers), 1)
-        self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
+        self.assertEqual(output.buffers[0], b"data")
 
        self.assertEqual(len(output.pte_data), 1)
-        self.assertEqual(output.pte_data["key"], 0)
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))
 
        self.assertEqual(len(output.external_data), 0)
 
@@ -56,12 +59,11 @@ def test_add_same_data_with_different_alignment(self) -> None:
         output = store.get_named_data_store_output()
 
         self.assertEqual(len(output.buffers), 1)
-        # Check that we take the LCM of the two alignments (3, 4) = 12
-        self.assertEqual(output.buffers[0], BufferEntry(b"data", 12))
+        self.assertEqual(output.buffers[0], b"data")
 
        self.assertEqual(len(output.pte_data), 2)
-        self.assertEqual(output.pte_data["key"], 0)
-        self.assertEqual(output.pte_data["key1"], 0)
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 3, None))
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 4, None))
 
        self.assertEqual(len(output.external_data), 0)
 
@@ -78,15 +80,30 @@ def test_add_duplicate_key_fail(self) -> None:
         output = store.get_named_data_store_output()
 
         self.assertEqual(len(output.buffers), 1)
-        self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
+        self.assertEqual(output.buffers[0], b"data")
 
        self.assertEqual(len(output.pte_data), 1)
-        self.assertEqual(output.pte_data["key"], 0)
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))
        self.assertEqual(len(output.external_data), 0)
 
+    def test_add_same_data_with_different_tensor_layout(self) -> None:
+        store = NamedDataStore()
+        tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
+        tensor_layout2 = TensorLayout(ScalarType.FLOAT, [2, 1], [0, 1])
+        store.add_named_data("key", b"data", None, None, tensor_layout1)
+        store.add_named_data("key1", b"data", None, None, tensor_layout2)
+
+        output = store.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 1)
+        self.assertEqual(output.buffers[0], b"data")
+
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 1, tensor_layout1))
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout2))
+
    def test_merge(self) -> None:
        store1 = NamedDataStore()
-        store1.add_named_data("key1", b"data1", None, None)
+        tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
+        store1.add_named_data("key1", b"data1", None, None, tensor_layout1)
        store1.add_named_data("key2", b"data2", 16, "file1")
 
        # Check items in the store1.
@@ -97,7 +114,7 @@ def test_merge(self) -> None:
         self.assertEqual(len(output.external_data["file1"]), 1)
 
         store2 = NamedDataStore()
-        store2.add_named_data("key1", b"data1", None, None)
+        store2.add_named_data("key1", b"data1", None, None, tensor_layout1)
        store2.add_named_data("key3", b"data3", None, None)
        store2.add_named_data("key4", b"data4", 16, "file1")
        store2.add_named_data("key5", b"data5", 16, "file2")
@@ -118,6 +135,8 @@ def test_merge(self) -> None:
         # key1, data1 exist in both store1 and store2, so we only have one copy of it.
         self.assertEqual(len(output.buffers), 5)
         self.assertEqual(len(output.pte_data), 2)
+        # Confirm DataEntry is correct.
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout1))
        self.assertEqual(len(output.external_data), 2)
        self.assertEqual(len(output.external_data["file1"]), 2)
        self.assertEqual(len(output.external_data["file2"]), 1)
