From 0268d465ea14d1472fa6ec793e43daa5a6b6239d Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 19 Feb 2025 14:20:45 -0800 Subject: [PATCH] Introduce NamedDataStore Introduce NamedDataStore for weight sharing. Rename 'NamedBlobStore' --> 'NamedDataStore' to mirror 'NamedDataMap' in the runtime. The NamedDataStore exposes two methods: - add_named_data: add a blob to the store - get_named_data_store_output: return the contents of the store, to pass to serialization. Invariants on the NamedDataStore - Keys are unique regardless of whether they are in PTE or external file. - Different keys can point to the same data. NamedDataStore is used in D69764150. It's owned by the EdgeProgramManager. Differential Revision: [D69764094](https://our.internmc.facebook.com/intern/diff/D69764094/) [ghstack-poisoned] --- exir/_serialize/TARGETS | 1 + exir/_serialize/_named_data_store.py | 178 ++++++++++++++++++ exir/_serialize/test/TARGETS | 10 + exir/_serialize/test/test_named_data_store.py | 85 +++++++++ 4 files changed, 274 insertions(+) create mode 100644 exir/_serialize/_named_data_store.py create mode 100644 exir/_serialize/test/test_named_data_store.py diff --git a/exir/_serialize/TARGETS b/exir/_serialize/TARGETS index cc6f16d78d8..6671bf00334 100644 --- a/exir/_serialize/TARGETS +++ b/exir/_serialize/TARGETS @@ -32,6 +32,7 @@ runtime.python_library( "_cord.py", "_dataclass.py", "_flatbuffer.py", + "_named_data_store.py", "_program.py", "_serialize.py", "data_serializer.py", diff --git a/exir/_serialize/_named_data_store.py b/exir/_serialize/_named_data_store.py new file mode 100644 index 00000000000..a3fc6fbd747 --- /dev/null +++ b/exir/_serialize/_named_data_store.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
# pyre-strict

import hashlib
import math
from dataclasses import dataclass
from typing import Dict, List, Optional


def gcd(a: int, b: int) -> int:
    """Return the greatest common divisor of ``a`` and ``b``.

    Delegates to ``math.gcd``, so the result is never negative and
    ``gcd(0, 0)`` is 0.
    """
    return math.gcd(a, b)


def lcm(a: int, b: int) -> int:
    """Return the least common multiple of ``a`` and ``b``.

    The result is never negative. If either argument is 0 the result is 0;
    this avoids dividing by ``gcd(0, 0) == 0``.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // gcd(a, b)


@dataclass
class BufferEntry:
    """A class to hold the buffer entries for serialization.

    Attributes:
        buffer: The buffer bytes.
        alignment: The alignment of the buffer.
    """

    buffer: bytes
    alignment: int


@dataclass
class NamedDataStoreOutput:
    """
    A class to hold the named data for serialization.

    Attributes:
        buffers: A list of unique buffer entries.
        pte_data: Contains data that is stored inside the PTE file. A mapping from
            {key: buffer_index}.
        external_data: Contains data that is stored external to the PTE. A mapping
            from {filename: {key: buffer_index}}.
    """

    buffers: List[BufferEntry]
    pte_data: Dict[str, int]
    external_data: Dict[str, Dict[str, int]]


class NamedDataStore:
    """
    NamedDataStore manages the data that delegates want to share. Backends add
    bytes to the store under a unique key. These bytes can be retrieved at
    runtime using the same key with the NamedDataMap.

    Note:
    - Keys are unique in the data store, regardless of whether they are stored
        in the PTE or externally. A key lives in exactly one location.
    - Multiple keys can point to the same buffer entry.
    - The same data can be added multiple times; all keys will point to one
        buffer. If a duplicate blob is added with a different alignment, the
        lcm of the current and new alignment is taken for that blob.
    """

    # List of unique blobs.
    buffers: List[BufferEntry]
    # Named data stored inside the PTE file. Map of {key: buffer_index}.
    pte_data: Dict[str, int]
    # Named data stored outside of the PTE file.
    # Map of {filename: {key: buffer_index}}.
    external_data: Dict[str, Dict[str, int]]

    # Cache of the data hash (sha256 hex digest -> buffer_index) for deduplication.
    data_cache: Dict[str, int]
    # Cache of the keys (key -> buffer_index) to ensure uniqueness.
    key_cache: Dict[str, int]

    def __init__(self) -> None:
        """
        Initializes a new, empty NamedDataStore.
        """
        self.buffers = []
        self.pte_data = {}
        self.external_data = {}

        self.data_cache = {}
        self.key_cache = {}

    def _add_named_data_to_map(
        self, key: str, data: bytes, alignment: int, key_map: Dict[str, int]
    ) -> None:
        """
        Add data to a map and update the alignment. Ensure that the key-data
        pair is unique.
        - If the key exists, the data must be identical and the key must
            already live in `key_map`.
        - If multiple unique keys exist for the same data, those keys
            point to the same buffer.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (int): alignment for bytes to be serialized with.
            key_map (Dict[str, int]): map to add the data to; either
                `self.pte_data` or one of the per-file maps in
                `self.external_data`.
        Raises:
            ValueError: when the key exists in the store and the corresponding
                data is different, or when the key exists in the store under a
                different location than `key_map`.
        """
        buffer_idx = self.key_cache.get(key)
        if buffer_idx is not None:
            # If the key exists, the corresponding data must be identical.
            if self.buffers[buffer_idx].buffer != data:
                raise ValueError(f"Duplicate key {key} with different data.")
            # key_cache and the location maps are always updated together, so
            # a known key that is missing from `key_map` lives in another map.
            # Adding it here would store one key in two locations.
            if key not in key_map:
                raise ValueError(
                    f"Duplicate key {key} is already stored in a different location."
                )
        else:
            # Key doesn't exist; check if the data exists.
            hashed = hashlib.sha256(data).hexdigest()
            buffer_idx = self.data_cache.get(hashed)
            if buffer_idx is None:
                # The data doesn't exist; add it to the data store.
                buffer_idx = len(self.buffers)
                self.buffers.append(BufferEntry(data, alignment))
                self.data_cache[hashed] = buffer_idx

        # Merge the requested alignment into the buffer. For a freshly added
        # buffer this is a no-op, because lcm(x, x) == x.
        entry = self.buffers[buffer_idx]
        entry.alignment = lcm(entry.alignment, alignment)

        # Add key to the map and the key cache.
        key_map[key] = buffer_idx
        self.key_cache[key] = buffer_idx

    def add_named_data(
        self,
        key: str,
        data: bytes,
        alignment: Optional[int] = 1,
        external_tag: Optional[str] = None,
    ) -> None:
        """
        Adds a named blob to the NamedDataStore.

        Args:
            key (str): key associated with the data.
            data (bytes): Bytes being requested to be serialized.
            alignment (Optional[int]): alignment for bytes to be serialized
                with. None is treated as 1. Must be greater than 0.
            external_tag (Optional[str]): the external filename that this data
                is saved to. When None, the data is stored in the PTE file.
        Raises:
            ValueError: when the alignment is not greater than 0, when the key
                exists in the store and the corresponding data is different, or
                when the key exists in the store under a different location.
        """
        # Set default alignment.
        if alignment is None:
            alignment = 1
        # Reject invalid alignments before touching any state: 0 would wipe
        # out the alignment of an existing buffer through lcm().
        if alignment <= 0:
            raise ValueError(f"Alignment must be greater than 0, received {alignment}.")

        if external_tag is None:
            target = self.pte_data
        else:
            # The per-file map is created before the add is validated; if the
            # add fails, the empty map is dropped in get_named_data_store_output.
            target = self.external_data.setdefault(external_tag, {})
        self._add_named_data_to_map(key, data, alignment, target)

    def get_named_data_store_output(self) -> NamedDataStoreOutput:
        """
        Return the contents of the store for serialization.

        Empty per-file maps are removed from `self.external_data` first. The
        returned object references the store's own list and dicts (no copy).
        """
        # Clean up empty maps inside self.external_data
        self.external_data = {k: v for k, v in self.external_data.items() if v}
        return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)


# ---------------------------------------------------------------------------
# Patch content that shared the last source line with the class above, kept
# verbatim so that nothing from the patch is dropped:
#
# diff --git a/exir/_serialize/test/TARGETS b/exir/_serialize/test/TARGETS
# index 853d82b8a9a..6d17f3d680c 100644
# --- a/exir/_serialize/test/TARGETS
# +++ b/exir/_serialize/test/TARGETS
# @@ -33,3 +33,13 @@ python_unittest(
#          "//executorch/exir/_serialize:lib",
#      ],
#  )
# +
# +python_unittest(
# +    name = "named_data_store",
# +    srcs = [
# +        "test_named_data_store.py",
# +    ],
# +    deps = [
# +        "//executorch/exir/_serialize:lib",
# +    ],
# +)
# diff --git a/exir/_serialize/test/test_named_data_store.py b/exir/_serialize/test/test_named_data_store.py
# new file mode 100644
# index 00000000000..d5355f6d7bf
# --- /dev/null
# +++ b/exir/_serialize/test/test_named_data_store.py
# @@ -0,0 +1,85 @@
# +# Copyright (c) Meta Platforms, Inc. and affiliates.
# +# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

from executorch.exir._serialize._named_data_store import BufferEntry, NamedDataStore


class TestNamedDataStore(unittest.TestCase):
    """Checks what NamedDataStore reports after blobs are added to it."""

    def test_add(self) -> None:
        # One blob in the PTE, two distinct blobs in the same external file.
        store = NamedDataStore()
        store.add_named_data("key1", b"data1", None, None)
        store.add_named_data("key2", b"data2", 16, "file1")
        store.add_named_data("key3", b"data3", 16, "file1")

        result = store.get_named_data_store_output()

        expected_buffers = [
            BufferEntry(b"data1", 1),
            BufferEntry(b"data2", 16),
            BufferEntry(b"data3", 16),
        ]
        self.assertEqual(result.buffers, expected_buffers)
        self.assertEqual(result.pte_data, {"key1": 0})
        self.assertEqual(result.external_data, {"file1": {"key2": 1, "key3": 2}})

    def test_add_duplicate_name_and_data(self) -> None:
        # Adding the identical key/data pair twice leaves a single entry.
        store = NamedDataStore()
        for _ in range(2):
            store.add_named_data("key", b"data", None, None)

        result = store.get_named_data_store_output()

        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})

    def test_add_same_data_with_different_alignment(self) -> None:
        # Two keys share one buffer; the alignments are merged.
        store = NamedDataStore()
        store.add_named_data("key", b"data", 3, None)
        store.add_named_data("key1", b"data", 4, None)

        result = store.get_named_data_store_output()

        # Check that we take the LCM of the two alignments (3, 4) = 12
        self.assertEqual(result.buffers, [BufferEntry(b"data", 12)])
        self.assertEqual(result.pte_data, {"key": 0, "key1": 0})
        self.assertEqual(result.external_data, {})

    def test_add_duplicate_key_fail(self) -> None:
        store = NamedDataStore()
        store.add_named_data("key", b"data", None, None)

        # Cannot add item with the same key and different data, whether it is
        # destined for the PTE or for an external file.
        with self.assertRaises(ValueError):
            store.add_named_data("key", b"data1", None, None)
        with self.assertRaises(ValueError):
            store.add_named_data("key", b"data1", 16, "file1")

        result = store.get_named_data_store_output()

        # The failed adds must not have changed the store.
        self.assertEqual(result.buffers, [BufferEntry(b"data", 1)])
        self.assertEqual(result.pte_data, {"key": 0})
        self.assertEqual(result.external_data, {})