diff --git a/exir/_serialize/_named_data_store.py b/exir/_serialize/_named_data_store.py
index 2c2d975937e..66c7a2a9c7b 100644
--- a/exir/_serialize/_named_data_store.py
+++ b/exir/_serialize/_named_data_store.py
@@ -7,42 +7,32 @@
 # pyre-strict
 
 import hashlib
-import math
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
-
-@dataclass
-class BufferEntry:
-    """A class to hold the buffer entries for serialization.
-
-    Attributes:
-        buffer: The buffer bytes.
-        alignment: The alignment of the buffer.
-    """
-
-    buffer: bytes
-    alignment: int
+from executorch.exir._serialize.data_serializer import DataEntry
+from executorch.exir.tensor_layout import TensorLayout
 
 
 @dataclass
 class NamedDataStoreOutput:
     """
-    Holds named data for serialization.
+    Holds named data for serialization. Note: a DataEntry contains the index into
+    `buffers`, the alignment, and a tensor layout, if applicable.
 
     Attributes:
         buffers: A list of unique buffer entries.
         pte_data: Contains data that is stored inside the PTE file. A mapping from
-            {key: buffer_index}.
+            {key: DataEntry}.
         external_data: Contains data that is stored external to the PTE. A mapping
-            from {filename: {key: buffer_index}}.
+            from {filename: {key: DataEntry}}.
     """
 
-    buffers: List[BufferEntry]
-    pte_data: Dict[str, int]
-    external_data: Dict[str, Dict[str, int]]
+    buffers: List[bytes]
+    pte_data: Dict[str, DataEntry]
+    external_data: Dict[str, Dict[str, DataEntry]]
 
 
 class NamedDataStore:
@@ -61,12 +51,12 @@ class NamedDataStore:
     """
 
     # List of unique blobs.
-    buffers: List[BufferEntry]
-    # Named data stored inside the PTE file. Map of {key: buffer_index}.
-    pte_data: Dict[str, int]
+    buffers: List[bytes]
+    # Named data stored inside the PTE file. Map of {key: DataEntry}.
+    pte_data: Dict[str, DataEntry]
     # Named data stored outside of the PTE file.
-    # Map of {filename: {key: buffer_index}}.
-    external_data: Dict[str, Dict[str, int]]
+    # Map of {filename: {key: DataEntry}}.
+    external_data: Dict[str, Dict[str, DataEntry]]
 
     # Cache of the data hash for deduplication.
     # Use a hash instead of the data as a key because a sha256 collision is
@@ -93,7 +83,8 @@ def _add_named_data_to_map(
         key: str,
         data: bytes,
         alignment: int,
-        local_key_to_buffer_idx: Dict[str, int],
+        local_key_to_buffer_idx: Dict[str, DataEntry],
+        tensor_layout: Optional[TensorLayout] = None,
     ) -> None:
         """
         Add data to a map and update the alignment. Ensure that the key-data
@@ -116,33 +107,31 @@
         # Check if the key exists.
         buffer_idx = self.key_to_buffer_idx.get(key, -1)
-        if buffer_idx != -1:
-            # If the key exists, the corresponding data must be identical.
-            if self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx:
-                raise ValueError(
-                    f"Duplicate key {key} with different data. "
-                    f"Existing data: {self.buffers[buffer_idx].buffer}. "
-                    f"New data: {data}."
-                )
-            self.buffers[buffer_idx].alignment = math.lcm(
-                self.buffers[buffer_idx].alignment, alignment
+        # If the key exists, the corresponding data must be identical.
+        if (
+            buffer_idx != -1
+            and self.data_hash_to_buffer_idx.get(hashed, -1) != buffer_idx
+        ):
+            raise ValueError(
+                f"Duplicate key {key} with different data. "
+                f"Existing data: {self.buffers[buffer_idx]}. "
+                f"New data: {data}."
             )
         else:
            # Key doesn't exist; check if the data exists.
             buffer_idx = self.data_hash_to_buffer_idx.get(hashed, -1)
-            if buffer_idx != -1:
-                # The data exists; update the alignment.
-                self.buffers[buffer_idx].alignment = math.lcm(
-                    self.buffers[buffer_idx].alignment, alignment
-                )
-            else:
+            if buffer_idx == -1:
                 # The data doesn't exist; add it to the data store.
                 buffer_idx = len(self.buffers)
-                self.buffers.append(BufferEntry(data, alignment))
+                self.buffers.append(data)
                 self.data_hash_to_buffer_idx[hashed] = buffer_idx
 
         # Add key to the map and the key cache.
-        local_key_to_buffer_idx[key] = buffer_idx
+        local_key_to_buffer_idx[key] = DataEntry(
+            buffer_index=buffer_idx,
+            alignment=alignment,
+            tensor_layout=tensor_layout,
+        )
         self.key_to_buffer_idx[key] = buffer_idx
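For illustration: the hunk above deduplicates bytes by their sha256 while alignment moves off the shared buffer and onto each key's DataEntry. A minimal self-contained sketch of that scheme; `Entry` and `add` are illustrative stand-ins, not the ExecuTorch API:

import hashlib
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Entry:  # stand-in for DataEntry (buffer_index, alignment)
    buffer_index: int
    alignment: int

buffers: List[bytes] = []
hash_to_idx: Dict[str, int] = {}
key_to_entry: Dict[str, Entry] = {}

def add(key: str, data: bytes, alignment: int = 1) -> None:
    hashed = hashlib.sha256(data).hexdigest()
    idx = hash_to_idx.get(hashed, -1)
    if idx == -1:
        # New bytes: stored once; later keys alias the same index.
        idx = len(buffers)
        buffers.append(data)
        hash_to_idx[hashed] = idx
    # Each key records its own alignment; the buffer itself is shared.
    key_to_entry[key] = Entry(buffer_index=idx, alignment=alignment)

add("weight", b"\x00" * 16, alignment=4)
add("bias", b"\x00" * 16, alignment=32)  # same bytes, different alignment
assert len(buffers) == 1
assert key_to_entry["weight"].buffer_index == key_to_entry["bias"].buffer_index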
 
     def add_named_data(
@@ -151,6 +140,7 @@
         self,
         data: bytes,
         alignment: Optional[int] = 1,
         external_tag: Optional[str] = None,
+        tensor_layout: Optional[TensorLayout] = None,
     ) -> None:
         """
         Adds a named blob to the NamedDataStore.
@@ -159,6 +149,7 @@
             data (bytes): Bytes being requested to be serialized.
             alignment (int): alignment for bytes to be serialized with.
             external (Optional[str]): the external filename that this data is saved to.
+            tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
 
         Raises:
             ValueError: when the key exists in the store, and corresponding data is different.
@@ -171,10 +162,16 @@
             raise ValueError(f"Alignment must be greater than 0, received {alignment}.")
 
         if external_tag is None:
-            self._add_named_data_to_map(key, data, alignment, self.pte_data)
+            self._add_named_data_to_map(
+                key, data, alignment, self.pte_data, tensor_layout
+            )
         else:
             self._add_named_data_to_map(
-                key, data, alignment, self.external_data.setdefault(external_tag, {})
+                key,
+                data,
+                alignment,
+                self.external_data.setdefault(external_tag, {}),
+                tensor_layout,
             )
 
     def get_named_data_store_output(self) -> NamedDataStoreOutput:
@@ -192,19 +189,22 @@ def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
         data is different between them.
         """
         # Merge the pte_data.
-        for key, buffer_idx in other.pte_data.items():
+        for key, data_entry in other.pte_data.items():
             self.add_named_data(
                 key,
-                other.buffers[buffer_idx].buffer,
-                other.buffers[buffer_idx].alignment,
+                other.buffers[data_entry.buffer_index],
+                data_entry.alignment,
+                external_tag=None,
+                tensor_layout=data_entry.tensor_layout,
             )
         # Merge the external_data.
-        for filename, key_to_buffer_idx in other.external_data.items():
-            for key, buffer_idx in key_to_buffer_idx.items():
+        for filename, key_to_data_entry in other.external_data.items():
+            for key, data_entry in key_to_data_entry.items():
                 self.add_named_data(
                     key,
-                    other.buffers[buffer_idx].buffer,
-                    other.buffers[buffer_idx].alignment,
+                    other.buffers[data_entry.buffer_index],
+                    data_entry.alignment,
                     external_tag=filename,
+                    tensor_layout=data_entry.tensor_layout,
                 )
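A hedged usage sketch of the updated NamedDataStore API, assuming an ExecuTorch checkout and the TensorLayout(scalar_type, sizes, dim_order) constructor shape the tests below use:

from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.scalar_type import ScalarType
from executorch.exir.tensor_layout import TensorLayout

store = NamedDataStore()
layout = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
# PTE-embedded blob; the layout rides along on its DataEntry.
store.add_named_data("key1", b"data1", None, None, layout)
# Externally tagged blob; lands in external_data["file1"].
store.add_named_data("key2", b"data2", 16, "file1")

out = store.get_named_data_store_output()
assert out.buffers == [b"data1", b"data2"]  # raw bytes, no BufferEntry wrapper
assert out.pte_data["key1"].tensor_layout == layout
assert out.external_data["file1"]["key2"].alignment == 16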
diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py
index 35a452c22ed..bee5b3438b0 100644
--- a/exir/_serialize/_program.py
+++ b/exir/_serialize/_program.py
@@ -12,7 +12,7 @@
 import re
 
 from dataclasses import dataclass
-from typing import ClassVar, Dict, List, Literal, Optional, Tuple
+from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple
 
 from executorch.exir._serialize._cord import Cord
 from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -21,10 +21,9 @@
     _program_flatbuffer_to_json,
     _program_json_to_flatbuffer,
 )
-from executorch.exir._serialize._named_data_store import (
-    BufferEntry,
-    NamedDataStoreOutput,
-)
+from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
+
+from executorch.exir._serialize.data_serializer import DataEntry
 
 from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
@@ -368,8 +367,8 @@ def _extract_constant_segment(
 def _extract_named_data(
     program: Program,
     segments: List[AlignedData],
-    buffers: List[BufferEntry],
-    name_to_buffer_idx: Dict[str, int],
+    buffers: Sequence[bytes],
+    name_to_data_entry: Dict[str, DataEntry],
 ) -> None:
     """Modifies the program in-place to add references to the named data segments.
 
@@ -379,7 +378,7 @@
         segments: A list of buffers to append extracted segments to. Modified in-place.
         buffers: A list of unique buffers and the information required to serialize
             them. Not modified.
-        name_to_buffer_idx: A map from the name of a blob to the index in buffers.
+        name_to_data_entry: A map from the blob name to DataEntry.
             Not modified.
     """
     if program.named_data is not None and len(program.named_data) > 0:
@@ -389,14 +388,14 @@
     segment_index_map: Dict[int, int] = {}
 
     named_data: List[NamedData] = []
-    for name, buffer_idx in name_to_buffer_idx.items():
-        segment_index = segment_index_map.get(buffer_idx, None)
+    for name, data_entry in name_to_data_entry.items():
+        segment_index = segment_index_map.get(data_entry.buffer_index, None)
         if segment_index is None:
             segment_index = len(segments)
-            segment_index_map[buffer_idx] = segment_index
+            segment_index_map[data_entry.buffer_index] = segment_index
             segments.append(
                 AlignedData(
-                    Cord(buffers[buffer_idx].buffer), buffers[buffer_idx].alignment
+                    Cord(buffers[data_entry.buffer_index]), data_entry.alignment
                 )
             )
         named_data.append(NamedData(key=name, segment_index=segment_index))
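For illustration: _extract_named_data above emits one segment per unique buffer_index even when several names reference it. A small sketch of that bookkeeping with stand-in types (ints instead of AlignedData/NamedData):

from typing import Dict, List, Tuple

def plan_segments(name_to_index: Dict[str, int]) -> Tuple[List[int], List[Tuple[str, int]]]:
    segment_index_map: Dict[int, int] = {}
    segments: List[int] = []            # stand-in for the AlignedData list
    named: List[Tuple[str, int]] = []   # stand-in for NamedData entries
    for name, buffer_index in name_to_index.items():
        seg = segment_index_map.get(buffer_index)
        if seg is None:
            # First time this buffer is seen: emit a new segment.
            seg = len(segments)
            segment_index_map[buffer_index] = seg
            segments.append(buffer_index)
        named.append((name, seg))
    return segments, named

segments, named = plan_segments({"a": 0, "b": 0, "c": 1})
assert segments == [0, 1]  # buffer 0 emitted once, shared by "a" and "b"
assert named == [("a", 0), ("b", 0), ("c", 1)]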
diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py
index 06e81997654..789ae89b190 100644
--- a/exir/_serialize/_serialize.py
+++ b/exir/_serialize/_serialize.py
@@ -117,18 +117,18 @@
                 )
                 buffers.append(emitter_output.external_constant_buffer[index])
 
-            # Extract external data.
+            # Extract external data from named_data_store.
             # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
-            key_to_buffer_index = named_data_store.external_data.get(tag, {})
-            for key, index in key_to_buffer_index.items():
+            blob_to_data_entry = named_data_store.external_data.get(tag, {})
+            for key, data_entry in blob_to_data_entry.items():
                 assert key not in key_to_data_entry  # key must be unique
                 key_to_data_entry[key] = DataEntry(
                     buffer_index=len(buffers),
-                    # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
-                    alignment=named_data_store.buffers[index].alignment,
-                    tensor_layout=None,
+                    alignment=data_entry.alignment,
+                    tensor_layout=data_entry.tensor_layout,
                 )
-                buffers.append(named_data_store.buffers[index].buffer)
+                # pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
+                buffers.append(named_data_store.buffers[data_entry.buffer_index])
 
             # Serialize into PTD file.
             ptd_files[tag] = data_serializer.serialize(
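For illustration: the loop above re-indexes store entries into a file-local buffer list, so DataEntry.buffer_index in the PTD output counts positions in the outgoing list rather than in the NamedDataStore. As written it appends once per key, so a blob shared by two keys is duplicated at this stage. A sketch with illustrative data:

from typing import Dict, List

store_buffers = [b"blob-a", b"blob-b"]
store_entries = {"a": 0, "b": 1, "a2": 0}  # key -> index into store_buffers

ptd_buffers: List[bytes] = []
ptd_entries: Dict[str, int] = {}
for key, idx in store_entries.items():
    ptd_entries[key] = len(ptd_buffers)  # index into the file-local list
    ptd_buffers.append(store_buffers[idx])

assert ptd_entries == {"a": 0, "b": 1, "a2": 2}
assert ptd_buffers == [b"blob-a", b"blob-b", b"blob-a"]  # "a" and "a2" both copied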
diff --git a/exir/_serialize/test/test_named_data_store.py b/exir/_serialize/test/test_named_data_store.py
index ffe6f2ddce7..b4ccf2e2cdb 100644
--- a/exir/_serialize/test/test_named_data_store.py
+++ b/exir/_serialize/test/test_named_data_store.py
@@ -8,7 +8,10 @@
 
 import unittest
 
-from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore
+from executorch.exir._serialize._named_data_store import NamedDataStore
+from executorch.exir._serialize.data_serializer import DataEntry
+from executorch.exir.scalar_type import ScalarType
+from executorch.exir.tensor_layout import TensorLayout
 
 
 class TestNamedDataStore(unittest.TestCase):
@@ -21,17 +24,17 @@ def test_add(self) -> None:
 
         output = store.get_named_data_store_output()
         self.assertEqual(len(output.buffers), 3)
-        self.assertEqual(output.buffers[0], BufferEntry(b"data1", 1))
-        self.assertEqual(output.buffers[1], BufferEntry(b"data2", 16))
-        self.assertEqual(output.buffers[2], BufferEntry(b"data3", 16))
+        self.assertEqual(output.buffers[0], b"data1")
+        self.assertEqual(output.buffers[1], b"data2")
+        self.assertEqual(output.buffers[2], b"data3")
 
         self.assertEqual(len(output.pte_data), 1)
-        self.assertEqual(output.pte_data["key1"], 0)
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, None))
 
         self.assertEqual(len(output.external_data), 1)
         self.assertEqual(len(output.external_data["file1"]), 2)
-        self.assertEqual(output.external_data["file1"]["key2"], 1)
-        self.assertEqual(output.external_data["file1"]["key3"], 2)
+        self.assertEqual(output.external_data["file1"]["key2"], DataEntry(1, 16, None))
+        self.assertEqual(output.external_data["file1"]["key3"], DataEntry(2, 16, None))
 
     def test_add_duplicate_name_and_data(self) -> None:
         store = NamedDataStore()
@@ -41,10 +44,10 @@ def test_add_duplicate_name_and_data(self) -> None:
 
         output = store.get_named_data_store_output()
         self.assertEqual(len(output.buffers), 1)
-        self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
+        self.assertEqual(output.buffers[0], b"data")
 
         self.assertEqual(len(output.pte_data), 1)
-        self.assertEqual(output.pte_data["key"], 0)
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))
 
         self.assertEqual(len(output.external_data), 0)
 
@@ -56,12 +59,11 @@ def test_add_same_data_with_different_alignment(self) -> None:
 
         output = store.get_named_data_store_output()
         self.assertEqual(len(output.buffers), 1)
-        # Check that we take the LCM of the two alignments (3, 4) = 12
-        self.assertEqual(output.buffers[0], BufferEntry(b"data", 12))
+        self.assertEqual(output.buffers[0], b"data")
 
         self.assertEqual(len(output.pte_data), 2)
-        self.assertEqual(output.pte_data["key"], 0)
-        self.assertEqual(output.pte_data["key1"], 0)
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 3, None))
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 4, None))
 
         self.assertEqual(len(output.external_data), 0)
 
@@ -78,15 +80,30 @@ def test_add_duplicate_key_fail(self) -> None:
 
         output = store.get_named_data_store_output()
         self.assertEqual(len(output.buffers), 1)
-        self.assertEqual(output.buffers[0], BufferEntry(b"data", 1))
+        self.assertEqual(output.buffers[0], b"data")
 
         self.assertEqual(len(output.pte_data), 1)
-        self.assertEqual(output.pte_data["key"], 0)
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 1, None))
 
         self.assertEqual(len(output.external_data), 0)
 
+    def test_add_same_data_with_different_tensor_layout(self) -> None:
+        store = NamedDataStore()
+        tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
+        tensor_layout2 = TensorLayout(ScalarType.FLOAT, [2, 1], [0, 1])
+        store.add_named_data("key", b"data", None, None, tensor_layout1)
+        store.add_named_data("key1", b"data", None, None, tensor_layout2)
+
+        output = store.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 1)
+        self.assertEqual(output.buffers[0], b"data")
+
+        self.assertEqual(output.pte_data["key"], DataEntry(0, 1, tensor_layout1))
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout2))
+
     def test_merge(self) -> None:
         store1 = NamedDataStore()
-        store1.add_named_data("key1", b"data1", None, None)
+        tensor_layout1 = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
+        store1.add_named_data("key1", b"data1", None, None, tensor_layout1)
         store1.add_named_data("key2", b"data2", 16, "file1")
 
         # Check items in the store1.
@@ -97,7 +114,7 @@ def test_merge(self) -> None:
         self.assertEqual(len(output.external_data["file1"]), 1)
 
         store2 = NamedDataStore()
-        store2.add_named_data("key1", b"data1", None, None)
+        store2.add_named_data("key1", b"data1", None, None, tensor_layout1)
         store2.add_named_data("key3", b"data3", None, None)
         store2.add_named_data("key4", b"data4", 16, "file1")
         store2.add_named_data("key5", b"data5", 16, "file2")
@@ -118,6 +135,8 @@ def test_merge(self) -> None:
         # key1, data1 exist in both store1 and store2, so we only have one copy of it.
         self.assertEqual(len(output.buffers), 5)
         self.assertEqual(len(output.pte_data), 2)
+        # Confirm DataEntry is correct.
+        self.assertEqual(output.pte_data["key1"], DataEntry(0, 1, tensor_layout1))
         self.assertEqual(len(output.external_data), 2)
         self.assertEqual(len(output.external_data["file1"]), 2)
         self.assertEqual(len(output.external_data["file2"]), 1)
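For illustration: test_add_same_data_with_different_alignment above pins down the behavior change. The store no longer folds alignments into one LCM'd BufferEntry; each key keeps its own alignment, and any combination happens at serialization time. The arithmetic, using the standard library:

import math

# Before this diff: one shared BufferEntry ended up with lcm(3, 4) == 12.
assert math.lcm(3, 4) == 12
# After: "key" keeps 3 and "key1" keeps 4 on their own DataEntry; a serializer
# can still combine a per-key alignment with a segment alignment on demand:
assert math.lcm(3, 128) == 384
assert math.lcm(256, 128) == 256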
diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py
index 7ed83569169..80f4b8ca49f 100644
--- a/exir/_serialize/test/test_program.py
+++ b/exir/_serialize/test/test_program.py
@@ -16,10 +16,7 @@
 from typing import List, Sequence
 
 from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json
-from executorch.exir._serialize._named_data_store import (
-    BufferEntry,
-    NamedDataStoreOutput,
-)
+from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
 from executorch.exir._serialize._program import (
     _ExtendedHeader,
     _get_extended_header,
@@ -28,6 +25,7 @@
     deserialize_pte_binary,
     serialize_pte_binary,
 )
+from executorch.exir._serialize.data_serializer import DataEntry
 from executorch.exir._serialize.padding import aligned_size
 
 from executorch.exir.schema import (
@@ -699,14 +697,14 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
 
         # Create named data segment.
         named_data_buffers = [
-            BufferEntry(
-                buffer=self.gen_blob_data(8, b"\x50\x55\x05"), alignment=3
-            ),  # expect lcm(3, 128) = 384
-            BufferEntry(
-                buffer=self.gen_blob_data(16, b"\x60\x66\x06"), alignment=256
-            ),  # expect lcm(256, 128) = 256
+            self.gen_blob_data(8, b"\x50\x55\x05"),
+            self.gen_blob_data(16, b"\x60\x66\x06"),
         ]
-        pte_named_data = {"key0": 0, "key1": 1}
+        buffer_alignment = [3, 256]
+        pte_named_data = {
+            "key0": DataEntry(0, buffer_alignment[0], None),  # expect lcm(3, 128) = 384
+            "key1": DataEntry(1, buffer_alignment[1], None),  # expect lcm(256, 128) = 256
+        }
         named_data = NamedDataStoreOutput(
             buffers=named_data_buffers, pte_data=pte_named_data, external_data={}
         )
@@ -762,16 +760,16 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
         # Named data segments.
         expected_offset = aligned_size(
             (segment_table[2].offset + segment_table[2].size),
-            math.lcm(named_data_buffers[0].alignment, SEGMENT_ALIGNMENT),
+            math.lcm(buffer_alignment[0], SEGMENT_ALIGNMENT),
         )
         self.assertEqual(segment_table[3].offset, expected_offset)
-        self.assertEqual(segment_table[3].size, len(named_data_buffers[0].buffer))
+        self.assertEqual(segment_table[3].size, len(named_data_buffers[0]))
 
         expected_offset = aligned_size(
             (segment_table[3].offset + segment_table[3].size),
-            math.lcm(named_data_buffers[1].alignment, SEGMENT_ALIGNMENT),
+            math.lcm(buffer_alignment[1], SEGMENT_ALIGNMENT),
         )
         self.assertEqual(segment_table[4].offset, expected_offset)
-        self.assertEqual(segment_table[4].size, len(named_data_buffers[1].buffer))
+        self.assertEqual(segment_table[4].size, len(named_data_buffers[1]))
 
         # Named data.
         self.assertTrue(program_with_segments.named_data is not None)
@@ -874,7 +872,7 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
                 segment_table[3].offset : segment_table[3].offset
                 + segment_table[3].size
             ],
-            named_data_buffers[0].buffer,
+            named_data_buffers[0],
         )
 
         self.assertEqual(
@@ -882,7 +880,7 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
                 segment_table[4].offset : segment_table[4].offset
                 + segment_table[4].size
             ],
-            named_data_buffers[1].buffer,
+            named_data_buffers[1],
         )
 
         # Convert back.
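For illustration: the expected_offset assertions above and below share one rule, round the previous segment's end up to a multiple of lcm(buffer alignment, segment alignment). A self-contained sketch with illustrative numbers; aligned_size here mirrors the helper in exir/_serialize/padding.py:

import math

def aligned_size(size: int, alignment: int) -> int:
    # Round `size` up to the next multiple of `alignment`.
    return math.ceil(size / alignment) * alignment

SEGMENT_ALIGNMENT = 12  # illustrative; the test picks its own constant
prev_end = 50           # offset + size of the preceding segment
assert math.lcm(8, SEGMENT_ALIGNMENT) == 24
assert aligned_size(prev_end, math.lcm(8, SEGMENT_ALIGNMENT)) == 72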
@@ -903,17 +901,17 @@ def test_named_data_segments(self) -> None:
 
         # Create named data segments with different alignments.
         buffers = [
-            BufferEntry(
-                buffer=self.gen_blob_data(8, b"\x10\x11\x01"), alignment=8
-            ),  # expect lcm(8, 12) = 24
-            BufferEntry(
-                buffer=self.gen_blob_data(16, b"\x20\x22\x02"), alignment=32
-            ),  # expect lcm(32, 12) = 96
-            BufferEntry(
-                buffer=self.gen_blob_data(24, b"\x30\x33\x03"), alignment=24
-            ),  # expect lcm(24, 12) = 24
+            self.gen_blob_data(8, b"\x10\x11\x01"),
+            self.gen_blob_data(16, b"\x20\x22\x02"),
+            self.gen_blob_data(24, b"\x30\x33\x03"),
         ]
-        pte_named_data = {"key1": 0, "key2": 0, "key3": 1, "key4": 2}
+        buffer_alignment = [8, 32, 24]
+        pte_named_data = {
+            "key1": DataEntry(0, buffer_alignment[0], None),  # expect lcm(8, 12) = 24
+            "key2": DataEntry(0, buffer_alignment[0], None),  # expect lcm(8, 12) = 24
+            "key3": DataEntry(1, buffer_alignment[1], None),  # expect lcm(32, 12) = 96
+            "key4": DataEntry(2, buffer_alignment[2], None),  # expect lcm(24, 12) = 24
+        }
         named_data = NamedDataStoreOutput(
             buffers=buffers, pte_data=pte_named_data, external_data={}
         )
@@ -965,10 +963,10 @@ def test_named_data_segments(self) -> None:
                 segment_table[i - 1].offset + segment_table[i - 1].size if i > 0 else 0
             )
             expected_offset = aligned_size(
-                segment_length, math.lcm(SEGMENT_ALIGNMENT, buffers[i].alignment)
+                segment_length, math.lcm(SEGMENT_ALIGNMENT, buffer_alignment[i])
             )
             self.assertEqual(segment_table[i].offset, expected_offset)
-            self.assertEqual(segment_table[i].size, len(buffers[i].buffer))
+            self.assertEqual(segment_table[i].size, len(buffers[i]))
 
         # Check the pte data for buffer values.
         segment_data: bytes = pte_data[eh.segment_base_offset :]
@@ -980,21 +978,21 @@ def test_named_data_segments(self) -> None:
                 segment_table[0].offset : segment_table[0].offset
                 + segment_table[0].size
             ],
-            buffers[0].buffer,
+            buffers[0],
         )
         self.assertEqual(
             segment_data[
                 segment_table[1].offset : segment_table[1].offset
                 + segment_table[1].size
             ],
-            buffers[1].buffer,
+            buffers[1],
         )
         self.assertEqual(
             segment_data[
                 segment_table[2].offset : segment_table[2].offset
                 + segment_table[2].size
             ],
-            buffers[2].buffer,
+            buffers[2],
         )
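Finally, a hedged end-to-end sketch of the merge path, assuming an ExecuTorch checkout; it mirrors test_merge above. Duplicate bytes collapse to one buffer while each key's DataEntry survives the merge:

from executorch.exir._serialize._named_data_store import NamedDataStore

store1 = NamedDataStore()
store1.add_named_data("key1", b"data1")
store2 = NamedDataStore()
store2.add_named_data("key1", b"data1")  # same key, same bytes: allowed
store2.add_named_data("key3", b"data3")

store1.merge_named_data_store(store2.get_named_data_store_output())
out = store1.get_named_data_store_output()
assert out.buffers == [b"data1", b"data3"]     # "data1" stored once
assert out.pte_data["key1"].buffer_index == 0  # DataEntry, not a bare index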