104 changes: 52 additions & 52 deletions exir/_serialize/_named_data_store.py
@@ -7,42 +7,32 @@
# pyre-strict

import hashlib
import math
from dataclasses import dataclass

# from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class BufferEntry:
"""A class to hold the buffer entries for serialization.

Attributes:
buffer: The buffer bytes.
alignment: The alignment of the buffer.
"""

buffer: bytes
alignment: int
from executorch.exir._serialize.data_serializer import DataEntry
from executorch.exir.tensor_layout import TensorLayout


@dataclass
class NamedDataStoreOutput:
"""
Holds named data for serialization.
Holds named data for serialization. Note: a DataEntry contains the index into
`buffers`, the alignment, and a tensor layout, if applicable.

Attributes:
buffers: A list of unique buffer entries.
pte_data: Contains data that is stored inside the PTE file. A mapping from
{key: buffer_index}.
{key: DataEntry}.
external_data: Contains data that is stored external to the PTE. A mapping
from {filename: {key: buffer_index}}.
from {filename: {key: DataEntry}}.
"""

buffers: List[BufferEntry]
pte_data: Dict[str, int]
external_data: Dict[str, Dict[str, int]]
buffers: List[bytes]
pte_data: Dict[str, DataEntry]
external_data: Dict[str, Dict[str, DataEntry]]
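# Illustrative shape of a populated output (hypothetical keys and values):
#   buffers       = [b"weights"]
#   pte_data      = {"w": DataEntry(buffer_index=0, alignment=16, tensor_layout=None)}
#   external_data = {"model.ptd": {"w2": DataEntry(0, 16, None)}}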


class NamedDataStore:
@@ -61,12 +51,12 @@ class NamedDataStore:
"""

# List of unique blobs.
buffers: List[BufferEntry]
# Named data stored inside the PTE file. Map of {key: buffer_index}.
pte_data: Dict[str, int]
buffers: List[bytes]
# Named data stored inside the PTE file. Map of {key: DataEntry}.
pte_data: Dict[str, DataEntry]
# Named data stored outside of the PTE file.
# Map of {filename: {key: buffer_index}}.
external_data: Dict[str, Dict[str, int]]
# Map of {filename: {key: DataEntry}}.
external_data: Dict[str, Dict[str, DataEntry]]

# Cache of the data hash for deduplication.
# Use a hash instead of the data as a key because a sha256 collision is
@@ -93,7 +83,8 @@ def _add_named_data_to_map(
key: str,
data: bytes,
alignment: int,
local_key_to_buffer_idx: Dict[str, int],
local_key_to_buffer_idx: Dict[str, DataEntry],
tensor_layout: Optional[TensorLayout] = None,
) -> None:
"""
Add data to a map and update the alignment. Ensure that the key-data
@@ -116,33 +107,31 @@ def _add_named_data_to_map(

# Check if the key exists.
buffer_idx = self.key_to_buffer_idx.get(key, -1)
if buffer_idx != -1:
# If the key exists, the corresponding data must be identical.
if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
raise ValueError(
f"Duplicate key {key} with different data. "
f"Existing data: {self.buffers[buffer_idx].buffer}. "
f"New data: {data}."
)
self.buffers[buffer_idx].alignment = math.lcm(
self.buffers[buffer_idx].alignment, alignment
# If the key exists, the corresponding data must be identical.
if (
buffer_idx != -1
and self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx
):
raise ValueError(
f"Duplicate key {key} with different data. "
f"Existing data: {self.buffers[buffer_idx]}. "
f"New data: {data}."
)
else:
# Key doesn't exist; check if the data exists.
buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
if buffer_idx != -1:
# The data exists; update the alignment.
self.buffers[buffer_idx].alignment = math.lcm(
self.buffers[buffer_idx].alignment, alignment
)
else:
if buffer_idx == -1:
# The data doesn't exist; add it to the data store.
buffer_idx = len(self.buffers)
self.buffers.append(BufferEntry(data, alignment))
self.buffers.append(data)
self.data_hash_to_buffer_idx[hashed] = buffer_idx

# Add key to the map and the key cache.
local_key_to_buffer_idx[key] = buffer_idx
local_key_to_buffer_idx[key] = DataEntry(
buffer_index=buffer_idx,
alignment=alignment,
tensor_layout=tensor_layout,
)
self.key_to_buffer_idx[key] = buffer_idx

def add_named_data(
@@ -151,6 +140,7 @@ def add_named_data(
data: bytes,
alignment: Optional[int] = 1,
external_tag: Optional[str] = None,
tensor_layout: Optional[TensorLayout] = None,
) -> None:
"""
Adds a named blob to the NamedDataStore.
@@ -159,6 +149,7 @@ def add_named_data(
data (bytes): Bytes being requested to be serialized.
alignment (Optional[int]): alignment for the bytes to be serialized with.
external_tag (Optional[str]): the external filename that this data is saved to.
tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
Raises:
ValueError: when the key exists in the store, and corresponding data
is different.
@@ -171,10 +162,16 @@
raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

if external_tag is None:
self._add_named_data_to_map(key, data, alignment, self.pte_data)
self._add_named_data_to_map(
key, data, alignment, self.pte_data, tensor_layout
)
else:
self._add_named_data_to_map(
key, data, alignment, self.external_data.setdefault(external_tag, {})
key,
data,
alignment,
self.external_data.setdefault(external_tag, {}),
tensor_layout,
)

def get_named_data_store_output(self) -> NamedDataStoreOutput:
@@ -192,19 +189,22 @@ def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
data is different between them.
"""
# Merge the pte_data.
for key, buffer_idx in other.pte_data.items():
for key, data_entry in other.pte_data.items():
self.add_named_data(
key,
other.buffers[buffer_idx].buffer,
other.buffers[buffer_idx].alignment,
other.buffers[data_entry.buffer_index],
data_entry.alignment,
external_tag=None,
tensor_layout=data_entry.tensor_layout,
)

# Merge the external_data.
for filename, key_to_buffer_idx in other.external_data.items():
for key, buffer_idx in key_to_buffer_idx.items():
for filename, key_to_data_entry in other.external_data.items():
for key, data_entry in key_to_data_entry.items():
self.add_named_data(
key,
other.buffers[buffer_idx].buffer,
other.buffers[buffer_idx].alignment,
other.buffers[data_entry.buffer_index],
data_entry.alignment,
external_tag=filename,
tensor_layout=data_entry.tensor_layout,
)
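A minimal usage sketch of the reworked store (assuming an executorch install; it mirrors the tests further down): identical bytes are stored once, and each key now carries its own alignment on its DataEntry rather than folding an LCM into the shared buffer.

from executorch.exir._serialize._named_data_store import NamedDataStore

store = NamedDataStore()
store.add_named_data("key", b"data", 3)
store.add_named_data("key1", b"data", 4)

out = store.get_named_data_store_output()
assert out.buffers == [b"data"]  # identical bytes deduplicated on content
assert out.pte_data["key"].alignment == 3  # alignment lives on the DataEntry
assert out.pte_data["key1"].alignment == 4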
23 changes: 11 additions & 12 deletions exir/_serialize/_program.py
@@ -12,7 +12,7 @@
import re

from dataclasses import dataclass
from typing import ClassVar, Dict, List, Literal, Optional, Tuple
from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -21,10 +21,9 @@
_program_flatbuffer_to_json,
_program_json_to_flatbuffer,
)
from executorch.exir._serialize._named_data_store import (
BufferEntry,
NamedDataStoreOutput,
)
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput

from executorch.exir._serialize.data_serializer import DataEntry

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

@@ -368,8 +367,8 @@ def _extract_constant_segment(
def _extract_named_data(
program: Program,
segments: List[AlignedData],
buffers: List[BufferEntry],
name_to_buffer_idx: Dict[str, int],
buffers: Sequence[bytes],
name_to_data_entry: Dict[str, DataEntry],
) -> None:
"""Modifies the program in-place to add references to the named data
segments.
@@ -379,7 +378,8 @@ def _extract_named_data(
segments: A list of buffers to append extracted segments to. Modified in-place.
buffers: A list of unique buffers and the information required to
serialize them. Not modified.
name_to_buffer_idx: A map from the name of a blob to the index in buffers.
name_to_data_entry: A map from the blob name to DataEntry.
Not modified.
"""
if program.named_data is not None and len(program.named_data) > 0:
@@ -389,14 +388,14 @@ def _extract_named_data(
segment_index_map: Dict[int, int] = {}

named_data: List[NamedData] = []
for name, buffer_idx in name_to_buffer_idx.items():
segment_index = segment_index_map.get(buffer_idx, None)
for name, data_entry in name_to_data_entry.items():
segment_index = segment_index_map.get(data_entry.buffer_index, None)
if segment_index is None:
segment_index = len(segments)
segment_index_map[buffer_idx] = segment_index
segment_index_map[data_entry.buffer_index] = segment_index
segments.append(
AlignedData(
Cord(buffers[buffer_idx].buffer), buffers[buffer_idx].alignment
Cord(buffers[data_entry.buffer_index]), data_entry.alignment
)
)
named_data.append(NamedData(key=name, segment_index=segment_index))
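The loop above collapses duplicate buffer references into a single segment shared by several names. A stripped-down sketch of that pattern in isolation (illustrative names, not the module's API):

from typing import Dict, List

def dedup_segments(buffers: List[bytes], name_to_index: Dict[str, int]) -> Dict[str, int]:
    """Map each name to a segment index, emitting each referenced buffer once."""
    segments: List[bytes] = []
    segment_of_buffer: Dict[int, int] = {}
    name_to_segment: Dict[str, int] = {}
    for name, buffer_idx in name_to_index.items():
        seg = segment_of_buffer.get(buffer_idx)
        if seg is None:
            seg = len(segments)
            segment_of_buffer[buffer_idx] = seg
            segments.append(buffers[buffer_idx])
        name_to_segment[name] = seg
    return name_to_segment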
14 changes: 7 additions & 7 deletions exir/_serialize/_serialize.py
@@ -117,18 +117,18 @@ def serialize_for_executorch(
)
buffers.append(emitter_output.external_constant_buffer[index])

# Extract external data.
# Extract external data from named_data_store.
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
key_to_buffer_index = named_data_store.external_data.get(tag, {})
for key, index in key_to_buffer_index.items():
blob_to_data_entry = named_data_store.external_data.get(tag, {})
for key, data_entry in blob_to_data_entry.items():
assert key not in key_to_data_entry # key must be unique
key_to_data_entry[key] = DataEntry(
buffer_index=len(buffers),
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
alignment=named_data_store.buffers[index].alignment,
tensor_layout=None,
alignment=data_entry.alignment,
tensor_layout=data_entry.tensor_layout,
)
buffers.append(named_data_store.buffers[index].buffer)
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
buffers.append(named_data_store.buffers[data_entry.buffer_index])

# Serialize into PTD file.
ptd_files[tag] = data_serializer.serialize(
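With alignment and tensor layout now carried on each DataEntry, extracting a tag's blobs reduces to re-indexing against the output buffer list. A hedged sketch of that remapping (illustrative helper name, not the module's API):

from typing import Dict, List

from executorch.exir._serialize.data_serializer import DataEntry

def remap_entries(
    store_buffers: List[bytes],
    store_entries: Dict[str, DataEntry],
    out_buffers: List[bytes],
) -> Dict[str, DataEntry]:
    """Append each referenced blob to out_buffers and re-index its DataEntry."""
    remapped: Dict[str, DataEntry] = {}
    for key, entry in store_entries.items():
        remapped[key] = DataEntry(
            buffer_index=len(out_buffers),
            alignment=entry.alignment,
            tensor_layout=entry.tensor_layout,
        )
        out_buffers.append(store_buffers[entry.buffer_index])
    return remapped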
53 changes: 36 additions & 17 deletions exir/_serialize/test/test_named_data_store.py
@@ -8,7 +8,10 @@

import unittest

from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._serialize.data_serializer import DataEntry
from executorch.exir.scalar_type import ScalarType
from executorch.exir.tensor_layout import TensorLayout


class TestNamedDataStore(unittest.TestCase):
@@ -21,17 +24,17 @@ def test_add(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 3)
self.assertEqual(output.buffers[0], BufferEntry(b"data1", 1))
self.assertEqual(output.buffers[1], BufferEntry(b"data2", 16))
self.assertEqual(output.buffers[2], BufferEntry(b"data3", 16))
self.assertEqual(output.buffers[0], b"data1")
self.assertEqual(output.buffers[1], b"data2")
self.assertEqual(output.buffers[2], b"data3")

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key1"], 0)
self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, None))

self.assertEqual(len(output.external_data), 1)
self.assertEqual(len(output.external_data["file1"]), 2)
self.assertEqual(output.external_data["file1"]["key2"], 1)
self.assertEqual(output.external_data["file1"]["key3"], 2)
self.assertEqual(output.external_data["file1"]["key2"], DataEntry(1, 16, None))
self.assertEqual(output.external_data["file1"]["key3"], DataEntry(2, 16, None))

def test_add_duplicate_name_and_data(self) -> None:
store = NamedDataStore()
@@ -41,10 +44,10 @@ def test_add_duplicate_name_and_data(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))

self.assertEqual(len(output.external_data), 0)

@@ -56,12 +59,11 @@ def test_add_same_data_with_different_alignment(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
# Check that we take the LCM of the two alignments (3, 4) = 12
self.assertEqual(output.buffers[0], BufferEntry(b"data", 12))
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(len(output.pte_data), 2)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(output.pte_data["key1"], 0)
self.assertEqual(output.pte_data["key"], DataEntry(0, 3, None))
self.assertEqual(output.pte_data["key1"], DataEntry(0, 4, None))

self.assertEqual(len(output.external_data), 0)

@@ -78,15 +80,30 @@ def test_add_duplicate_key_fail(self) -> None:
output = store.get_named_data_store_output()

self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))
self.assertEqual(len(output.external_data), 0)

def test_add_same_data_with_different_tensor_layout(self) -> None:
store = NamedDataStore()
tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
tensor_layout2 = TensorLayout(ScalarType.FLOAT, [2, 1], [0, 1])
store.add_named_data("key", b"data", None, None, tensor_layout1)
store.add_named_data("key1", b"data", None, None, tensor_layout2)

output = store.get_named_data_store_output()
self.assertEqual(len(output.buffers), 1)
self.assertEqual(output.buffers[0], b"data")

self.assertEqual(output.pte_data["key"], DataEntry(0, 1, tensor_layout1))
self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout2))
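# Note: dedup is keyed on content, so one buffer can back entries with different TensorLayouts.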

def test_merge(self) -> None:
store1 = NamedDataStore()
store1.add_named_data("key1", b"data1", None, None)
tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
store1.add_named_data("key1", b"data1", None, None, tensor_layout1)
store1.add_named_data("key2", b"data2", 16, "file1")

# Check items in the store1.
@@ -97,7 +114,7 @@ def test_merge(self) -> None:
self.assertEqual(len(output.external_data["file1"]), 1)

store2 = NamedDataStore()
store2.add_named_data("key1", b"data1", None, None)
store2.add_named_data("key1", b"data1", None, None, tensor_layout1)
store2.add_named_data("key3", b"data3", None, None)
store2.add_named_data("key4", b"data4", 16, "file1")
store2.add_named_data("key5", b"data5", 16, "file2")
@@ -118,6 +135,8 @@ def test_merge(self) -> None:
# key1, data1 exist in both store1 and store2, so we only have one copy of it.
self.assertEqual(len(output.buffers), 5)
self.assertEqual(len(output.pte_data), 2)
# Confirm DataEntry is correct.
self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout1))
self.assertEqual(len(output.external_data), 2)
self.assertEqual(len(output.external_data["file1"]), 2)
self.assertEqual(len(output.external_data["file2"]), 1)