# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import hashlib
import math
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class BufferEntry:
    """A class to hold the buffer entries for serialization.

    Attributes:
        buffer: The buffer bytes.
        alignment: The alignment of the buffer.
    """

    buffer: bytes
    alignment: int


@dataclass
class NamedDataStoreOutput:
    """
    Holds named data for serialization.

    Attributes:
        buffers: A list of unique buffer entries.
        pte_data: Contains data that is stored inside the PTE file. A mapping from
            {key: buffer_index}.
        external_data: Contains data that is stored external to the PTE. A mapping
            from {filename: {key: buffer_index}}.
    """

    buffers: List[BufferEntry]
    pte_data: Dict[str, int]
    external_data: Dict[str, Dict[str, int]]


class NamedDataStore:
    """
    NamedDataStore manages the data that delegates want to share. Backends add
    bytes to the store under a unique key. These bytes can be retrieved at
    runtime using the same key with the NamedDataMap.

    Note:
    - Keys are unique in the data store, regardless of whether they are stored
        in the PTE or externally.
    - Multiple keys can point to the same buffer entry.
    - The same data can be added multiple times and all keys will point to one
        buffer. If a duplicate blob is added with a different alignment, the
        lcm of the current and new alignment is taken for that blob.
    """

    # Upper bound on how many bytes of a blob are echoed in an error message.
    _MAX_PREVIEW_BYTES: int = 64

    # List of unique blobs.
    buffers: List[BufferEntry]
    # Named data stored inside the PTE file. Map of {key: buffer_index}.
    pte_data: Dict[str, int]
    # Named data stored outside of the PTE file.
    # Map of {filename: {key: buffer_index}}.
    external_data: Dict[str, Dict[str, int]]

    # Cache of the data hash for deduplication.
    # Use a hash instead of the data as a key because a sha256 collision is
    # unlikely, and the data may be large.
    data_hash_to_buffer_idx: Dict[bytes, int]
    # Cache of the key to buffer idx to ensure uniqueness.
    # If a key is added multiple times, check the buffer idx to ensure that the
    # data is identical too.
    key_to_buffer_idx: Dict[str, int]

    def __init__(self) -> None:
        """
        Initializes a new, empty NamedDataStore.
        """
        self.buffers = []
        self.pte_data = {}
        self.external_data = {}

        self.data_hash_to_buffer_idx = {}
        self.key_to_buffer_idx = {}

    @staticmethod
    def _preview(data: bytes, limit: int = 64) -> str:
        """
        Render `data` for an error message without dumping a huge blob.

        Blobs of at most `limit` bytes are rendered in full (as their bytes
        repr); longer blobs are cut to the first `limit` bytes and annotated
        with the total size.
        """
        if len(data) <= limit:
            return repr(data)
        return f"{data[:limit]!r}... ({len(data)} bytes total)"

    def _add_named_data_to_map(
        self,
        key: str,
        data: bytes,
        alignment: int,
        local_key_to_buffer_idx: Dict[str, int],
    ) -> None:
        """
        Add data to a map and update the alignment. Ensure that the key-data
        pair is unique.
        - If the key exists, the data must be identical.
        - If multiple unique keys exist for the same data, those keys should
            point to the same buffer.

        Nothing in the store is modified when ValueError is raised; the check
        runs before any state is touched.

        NOTE(review): re-adding an existing key with identical data but a
        different `local_key_to_buffer_idx` records the key in both maps.
        Confirm whether that is intended, given that keys are documented as
        unique across PTE and external data.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (int): alignment for bytes to be serialized with.
            local_key_to_buffer_idx (Dict[str, int]): map to add the data to.
        Raises:
            ValueError: when the key exists in the store, and corresponding data
                is different.
        """
        hashed = hashlib.sha256(data).digest()
        # Index of an already-stored blob with identical content, if any.
        data_idx = self.data_hash_to_buffer_idx.get(hashed)
        # Index the key already points to, if the key was seen before.
        buffer_idx = self.key_to_buffer_idx.get(key)

        if buffer_idx is not None:
            # If the key exists, the corresponding data must be identical.
            if data_idx != buffer_idx:
                limit = self._MAX_PREVIEW_BYTES
                existing = self._preview(self.buffers[buffer_idx].buffer, limit)
                raise ValueError(
                    f"Duplicate key {key} with different data. "
                    f"Existing data: {existing}. "
                    f"New data: {self._preview(data, limit)}."
                )
        else:
            # New key; reuse the blob if the same data was stored before.
            buffer_idx = data_idx

        if buffer_idx is None:
            # The data doesn't exist; add it to the data store.
            buffer_idx = len(self.buffers)
            self.buffers.append(BufferEntry(data, alignment))
            self.data_hash_to_buffer_idx[hashed] = buffer_idx
        else:
            # The data exists; the shared blob must satisfy every requested
            # alignment, hence the least common multiple.
            entry = self.buffers[buffer_idx]
            entry.alignment = math.lcm(entry.alignment, alignment)

        # Add key to the map and the key cache.
        local_key_to_buffer_idx[key] = buffer_idx
        self.key_to_buffer_idx[key] = buffer_idx

    def add_named_data(
        self,
        key: str,
        data: bytes,
        alignment: Optional[int] = 1,
        external_tag: Optional[str] = None,
    ) -> None:
        """
        Adds a named blob to the NamedDataStore.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (Optional[int]): alignment for bytes to be serialized
                with. None is treated as 1.
            external_tag (Optional[str]): the external filename that this data
                is saved to. When None, the data is stored in the PTE map.
        Raises:
            ValueError: when the alignment is not greater than 0, or when the
                key exists in the store and the corresponding data is
                different.
        """
        # Set default alignment.
        if alignment is None:
            alignment = 1
        if alignment <= 0:
            raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

        if external_tag is None:
            self._add_named_data_to_map(key, data, alignment, self.pte_data)
            return

        # Register the per-file map only after the entry has been accepted, so
        # a rejected add does not leave an empty map behind in external_data.
        file_map = self.external_data.get(external_tag, {})
        self._add_named_data_to_map(key, data, alignment, file_map)
        self.external_data[external_tag] = file_map

    def get_named_data_store_output(self) -> NamedDataStoreOutput:
        """
        Package the current contents of the store for serialization.

        Empty per-file maps are dropped from `self.external_data` first. The
        returned object references the store's own list and dicts; they are
        not copied.

        Returns:
            NamedDataStoreOutput holding the buffers, the PTE map and the
            external maps.
        """
        # Clean up empty maps inside self.external_data
        self.external_data = {k: v for k, v in self.external_data.items() if v}
        return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)
# pyre-strict

import unittest

from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore


class TestNamedDataStore(unittest.TestCase):
    """Unit tests for adding and deduplicating blobs in NamedDataStore."""

    def test_add(self) -> None:
        # One blob in the PTE and two distinct blobs in the same external file.
        store = NamedDataStore()
        store.add_named_data("key1", b"data1", None, None)
        store.add_named_data("key2", b"data2", 16, "file1")
        store.add_named_data("key3", b"data3", 16, "file1")

        result = store.get_named_data_store_output()

        expected_buffers = [
            BufferEntry(b"data1", 1),
            BufferEntry(b"data2", 16),
            BufferEntry(b"data3", 16),
        ]
        self.assertEqual(result.buffers, expected_buffers)
        self.assertEqual(result.pte_data, {"key1": 0})
        self.assertEqual(result.external_data, {"file1": {"key2": 1, "key3": 2}})

    def test_add_duplicate_name_and_data(self) -> None:
        # Adding the identical key/data pair twice stores a single entry.
        store = NamedDataStore()
        for _ in range(2):
            store.add_named_data("key", b"data", None, None)

        result = store.get_named_data_store_output()

        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})

    def test_add_same_data_with_different_alignment(self) -> None:
        store = NamedDataStore()
        store.add_named_data("key", b"data", 3, None)
        store.add_named_data("key1", b"data", 4, None)

        result = store.get_named_data_store_output()

        # Check that we take the LCM of the two alignments (3, 4) = 12
        self.assertEqual(result.buffers, [BufferEntry(b"data", 12)])
        self.assertEqual(result.pte_data, {"key": 0, "key1": 0})
        self.assertEqual(result.external_data, {})

    def test_add_duplicate_key_fail(self) -> None:
        store = NamedDataStore()
        store.add_named_data("key", b"data", None, None)

        # Cannot add item with the same key and different data.
        with self.assertRaises(ValueError):
            store.add_named_data("key", b"data1", None, None)
        with self.assertRaises(ValueError):
            store.add_named_data("key", b"data1", 16, "file1")

        result = store.get_named_data_store_output()

        # The rejected adds must not have changed the store.
        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})