Skip to content

Commit 01e9d97

Browse files
committed
Reuse types in _named_data_store and support tensor layouts
Reuse `DataEntry` from data_serializer.py in _named_data_store.py. Motivation: (1) deserialize from a flat tensor into a named data store output; (2) support tensor layouts in the named data store. Differential Revision: [D83490345](https://our.internmc.facebook.com/intern/diff/D83490345/) ghstack-source-id: 312807877 Pull Request resolved: #14667
1 parent 782275f commit 01e9d97

File tree

5 files changed

+137
-121
lines changed

5 files changed

+137
-121
lines changed

exir/_serialize/_named_data_store.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,32 @@
77
# pyre-strict
88

99
import hashlib
10-
import math
1110
from dataclasses import dataclass
1211

1312
# from dataclasses import dataclass
14-
from typing import Dict, List, Optional
13+
from typing import Dict, Optional, Sequence
1514

16-
17-
@dataclass
18-
class BufferEntry:
19-
"""A class to hold the buffer entries for serialization.
20-
21-
Attributes:
22-
buffer: The buffer bytes.
23-
alignment: The alignment of the buffer.
24-
"""
25-
26-
buffer: bytes
27-
alignment: int
15+
from executorch.exir._serialize.data_serializer import DataEntry
16+
from executorch.exir.tensor_layout import TensorLayout
2817

2918

3019
@dataclass
3120
class NamedDataStoreOutput:
3221
"""
33-
Holds named data for serialization.
22+
Holds named data for serialization. Note: a DataEntry contains the index into
23+
`buffers`, the alignment and a tensor layout, if applicable.
3424
3525
Attributes:
3626
buffers: A list of unique buffer entries.
3727
pte_data: Contains data that is stored inside the PTE file. A mapping from
38-
{key: buffer_index}.
28+
{key: DataEntry}.
3929
external_data: Contains data that is stored external to the PTE. A mapping
40-
from {filename: {key: buffer_index}}.
30+
from {filename: {key: DataEntry}}.
4131
"""
4232

43-
buffers: List[BufferEntry]
44-
pte_data: Dict[str, int]
45-
external_data: Dict[str, Dict[str, int]]
33+
buffers: Sequence[bytes]
34+
pte_data: Dict[str, DataEntry]
35+
external_data: Dict[str, Dict[str, DataEntry]]
4636

4737

4838
class NamedDataStore:
@@ -61,12 +51,12 @@ class NamedDataStore:
6151
"""
6252

6353
# List of unique blobs.
64-
buffers: List[BufferEntry]
65-
# Named data stored inside the PTE file. Map of {key: buffer_index}.
66-
pte_data: Dict[str, int]
54+
buffers: Sequence[bytes]
55+
# Named data stored inside the PTE file. Map of {key: DataEntry}.
56+
pte_data: Dict[str, DataEntry]
6757
# Named data stored outside of the PTE file.
68-
# Map of {filename: {key: buffer_index}}.
69-
external_data: Dict[str, Dict[str, int]]
58+
# Map of {filename: {key: DataEntry}}.
59+
external_data: Dict[str, Dict[str, DataEntry]]
7060

7161
# Cache of the data hash for deduplication.
7262
# Use a hash instead of the data as a key because a sha256 collision is
@@ -93,7 +83,8 @@ def _add_named_data_to_map(
9383
key: str,
9484
data: bytes,
9585
alignment: int,
96-
local_key_to_buffer_idx: Dict[str, int],
86+
local_key_to_buffer_idx: Dict[str, DataEntry],
87+
tensor_layout: Optional[TensorLayout] = None,
9788
) -> None:
9889
"""
9990
Add data to a map and update the alignment. Ensure that the key-data
@@ -116,33 +107,31 @@ def _add_named_data_to_map(
116107

117108
# Check if the key exists.
118109
buffer_idx = self.key_to_buffer_idx.get(key, -1)
119-
if buffer_idx != -1:
120-
# If the key exists, the corresponding data must be identical.
121-
if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
122-
raise ValueError(
123-
f"Duplicate key {key} with different data. "
124-
f"Existing data: {self.buffers[buffer_idx].buffer}. "
125-
f"New data: {data}."
126-
)
127-
self.buffers[buffer_idx].alignment = math.lcm(
128-
self.buffers[buffer_idx].alignment, alignment
110+
# If the key exists, the corresponding data must be identical.
111+
if (
112+
buffer_idx != -1
113+
and self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx
114+
):
115+
raise ValueError(
116+
f"Duplicate key {key} with different data. "
117+
f"Existing data: {self.buffers[buffer_idx]}. "
118+
f"New data: {data}."
129119
)
130120
else:
131121
# Key doesn't exist; check if the data exists.
132122
buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
133-
if buffer_idx != -1:
134-
# The data exists; update the alignment.
135-
self.buffers[buffer_idx].alignment = math.lcm(
136-
self.buffers[buffer_idx].alignment, alignment
137-
)
138-
else:
123+
if buffer_idx == -1:
139124
# The data doesn't exist; add it to the data store.
140125
buffer_idx = len(self.buffers)
141-
self.buffers.append(BufferEntry(data, alignment))
126+
self.buffers.append(data)
142127
self.data_hash_to_buffer_idx[hashed] = buffer_idx
143128

144129
# Add key to the map and the key cache.
145-
local_key_to_buffer_idx[key] = buffer_idx
130+
local_key_to_buffer_idx[key] = DataEntry(
131+
buffer_index=buffer_idx,
132+
alignment=alignment,
133+
tensor_layout=tensor_layout,
134+
)
146135
self.key_to_buffer_idx[key] = buffer_idx
147136

148137
def add_named_data(
@@ -151,6 +140,7 @@ def add_named_data(
151140
data: bytes,
152141
alignment: Optional[int] = 1,
153142
external_tag: Optional[str] = None,
143+
tensor_layout: Optional[TensorLayout] = None,
154144
) -> None:
155145
"""
156146
Adds a named blob to the NamedDataStore.
@@ -159,6 +149,7 @@ def add_named_data(
159149
data (bytes): Bytes being requested to be serialized.
160150
alignment (int): alignment for bytes to be serialized with.
161151
external (Optional[str]): the external filename that this data is saved to.
152+
tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
162153
Raises:
163154
ValueError: when the key exists in the store, and corresponding data
164155
is different.
@@ -171,10 +162,16 @@ def add_named_data(
171162
raise ValueError(f"Alignment must be greater than 0, received {alignment}.")
172163

173164
if external_tag is None:
174-
self._add_named_data_to_map(key, data, alignment, self.pte_data)
165+
self._add_named_data_to_map(
166+
key, data, alignment, self.pte_data, tensor_layout
167+
)
175168
else:
176169
self._add_named_data_to_map(
177-
key, data, alignment, self.external_data.setdefault(external_tag, {})
170+
key,
171+
data,
172+
alignment,
173+
self.external_data.setdefault(external_tag, {}),
174+
tensor_layout,
178175
)
179176

180177
def get_named_data_store_output(self) -> NamedDataStoreOutput:
@@ -192,19 +189,22 @@ def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
192189
data is different between them.
193190
"""
194191
# Merge the pte_data.
195-
for key, buffer_idx in other.pte_data.items():
192+
for key, data_entry in other.pte_data.items():
196193
self.add_named_data(
197194
key,
198-
other.buffers[buffer_idx].buffer,
199-
other.buffers[buffer_idx].alignment,
195+
other.buffers[data_entry.buffer_index],
196+
data_entry.alignment,
197+
external_tag=None,
198+
tensor_layout=data_entry.tensor_layout,
200199
)
201200

202201
# Merge the external_data.
203-
for filename, key_to_buffer_idx in other.external_data.items():
204-
for key, buffer_idx in key_to_buffer_idx.items():
202+
for filename, key_to_data_entry in other.external_data.items():
203+
for key, data_entry in key_to_data_entry.items():
205204
self.add_named_data(
206205
key,
207-
other.buffers[buffer_idx].buffer,
208-
other.buffers[buffer_idx].alignment,
206+
other.buffers[data_entry.buffer_index],
207+
data_entry.alignment,
209208
external_tag=filename,
209+
tensor_layout=data_entry.tensor_layout,
210210
)

exir/_serialize/_program.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import re
1313

1414
from dataclasses import dataclass
15-
from typing import ClassVar, Dict, List, Literal, Optional, Tuple
15+
from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple
1616

1717
from executorch.exir._serialize._cord import Cord
1818
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -21,10 +21,9 @@
2121
_program_flatbuffer_to_json,
2222
_program_json_to_flatbuffer,
2323
)
24-
from executorch.exir._serialize._named_data_store import (
25-
BufferEntry,
26-
NamedDataStoreOutput,
27-
)
24+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
25+
26+
from executorch.exir._serialize.data_serializer import DataEntry
2827

2928
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
3029

@@ -368,8 +367,8 @@ def _extract_constant_segment(
368367
def _extract_named_data(
369368
program: Program,
370369
segments: List[AlignedData],
371-
buffers: List[BufferEntry],
372-
name_to_buffer_idx: Dict[str, int],
370+
buffers: Sequence[bytes],
371+
name_to_data_entry: Dict[str, DataEntry],
373372
) -> None:
374373
"""Modifies the program in-place to add references to the named data
375374
segments.
@@ -379,7 +378,7 @@ def _extract_named_data(
379378
segments: A list of buffers to append extracted segments to. Modified in-place.
380379
buffers: A list of unique buffers and the information required to
381380
serialize them. Not modified.
382-
name_to_buffer_idx: A map from the name of a blob to the index in buffers.
381+
name_to_data_entry: A map from the blob name to DataEntry.
383382
Not modified.
384383
"""
385384
if program.named_data is not None and len(program.named_data) > 0:
@@ -389,14 +388,14 @@ def _extract_named_data(
389388
segment_index_map: Dict[int, int] = {}
390389

391390
named_data: List[NamedData] = []
392-
for name, buffer_idx in name_to_buffer_idx.items():
393-
segment_index = segment_index_map.get(buffer_idx, None)
391+
for name, data_entry in name_to_data_entry.items():
392+
segment_index = segment_index_map.get(data_entry.buffer_index, None)
394393
if segment_index is None:
395394
segment_index = len(segments)
396-
segment_index_map[buffer_idx] = segment_index
395+
segment_index_map[data_entry.buffer_index] = segment_index
397396
segments.append(
398397
AlignedData(
399-
Cord(buffers[buffer_idx].buffer), buffers[buffer_idx].alignment
398+
Cord(buffers[data_entry.buffer_index]), data_entry.alignment
400399
)
401400
)
402401
named_data.append(NamedData(key=name, segment_index=segment_index))

exir/_serialize/_serialize.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,18 @@ def serialize_for_executorch(
117117
)
118118
buffers.append(emitter_output.external_constant_buffer[index])
119119

120-
# Extract external data.
120+
# Extract external data from named_data_store.
121121
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
122-
key_to_buffer_index = named_data_store.external_data.get(tag, {})
123-
for key, index in key_to_buffer_index.items():
122+
blob_to_data_entry = named_data_store.external_data.get(tag, {})
123+
for key, data_entry in blob_to_data_entry.items():
124124
assert key not in key_to_data_entry # key must be unique
125125
key_to_data_entry[key] = DataEntry(
126126
buffer_index=len(buffers),
127-
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
128-
alignment=named_data_store.buffers[index].alignment,
129-
tensor_layout=None,
127+
alignment=data_entry.alignment,
128+
tensor_layout=data_entry.tensor_layout,
130129
)
131-
buffers.append(named_data_store.buffers[index].buffer)
130+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
131+
buffers.append(named_data_store.buffers[data_entry.buffer_index])
132132

133133
# Serialize into PTD file.
134134
ptd_files[tag] = data_serializer.serialize(

0 commit comments

Comments
 (0)