From 2c5ed9152a7ce01f9394d1387f9e39098f306a2a Mon Sep 17 00:00:00 2001
From: Lucy Qiu
Date: Fri, 28 Feb 2025 14:42:23 -0800
Subject: [PATCH] add merge function for NamedDataStore (#8850)

Summary:
Allow us to change the NamedDataStore model. Usage will now be:
```
named_data_store
for backend in backends:
    bytes, named_data_store = backend.preprocess()
    named_data_store.merge(named_data_store)
```

Note:
- named_data_store is held by edge program manager
- preprocess returns the named data store as part of PreprocessResult

Differential Revision: D70409078
---
 exir/_serialize/_named_data_store.py          | 27 +++++++++
 exir/_serialize/test/test_named_data_store.py | 59 +++++++++++++++++++
 2 files changed, 86 insertions(+)

diff --git a/exir/_serialize/_named_data_store.py b/exir/_serialize/_named_data_store.py
index 999913a4bb0..2c2d975937e 100644
--- a/exir/_serialize/_named_data_store.py
+++ b/exir/_serialize/_named_data_store.py
@@ -181,3 +181,30 @@ def get_named_data_store_output(self) -> NamedDataStoreOutput:
         # Clean up empty maps inside self.external_data
         self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
         return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)
+
+    def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
+        """
+        Merge the contents of a NamedDataStoreOutput into this store.
+        Args:
+            other (NamedDataStoreOutput): the store output whose entries are added.
+        Raises:
+            ValueError: when the key exists in both stores, and corresponding
+                data is different between them.
+        """
+        # Merge the pte_data: re-add each key with its buffer and alignment.
+        for key, buffer_idx in other.pte_data.items():
+            self.add_named_data(
+                key,
+                other.buffers[buffer_idx].buffer,
+                other.buffers[buffer_idx].alignment,
+            )
+
+        # Merge the external_data, passing each filename as the external_tag.
+        for filename, key_to_buffer_idx in other.external_data.items():
+            for key, buffer_idx in key_to_buffer_idx.items():
+                self.add_named_data(
+                    key,
+                    other.buffers[buffer_idx].buffer,
+                    other.buffers[buffer_idx].alignment,
+                    external_tag=filename,
+                )
diff --git a/exir/_serialize/test/test_named_data_store.py b/exir/_serialize/test/test_named_data_store.py
index d5355f6d7bf..ffe6f2ddce7 100644
--- a/exir/_serialize/test/test_named_data_store.py
+++ b/exir/_serialize/test/test_named_data_store.py
@@ -83,3 +83,62 @@ def test_add_duplicate_key_fail(self) -> None:
         self.assertEqual(len(output.pte_data), 1)
         self.assertEqual(output.pte_data["key"], 0)
         self.assertEqual(len(output.external_data), 0)
+
+    def test_merge(self) -> None:
+        store1 = NamedDataStore()
+        store1.add_named_data("key1", b"data1", None, None)
+        store1.add_named_data("key2", b"data2", 16, "file1")
+
+        # Check items in store1 before the merge.
+        output = store1.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 2)
+        self.assertEqual(len(output.pte_data), 1)
+        self.assertEqual(len(output.external_data), 1)
+        self.assertEqual(len(output.external_data["file1"]), 1)
+
+        store2 = NamedDataStore()
+        store2.add_named_data("key1", b"data1", None, None)
+        store2.add_named_data("key3", b"data3", None, None)
+        store2.add_named_data("key4", b"data4", 16, "file1")
+        store2.add_named_data("key5", b"data5", 16, "file2")
+
+        # Check items in store2 before the merge.
+        output2 = store2.get_named_data_store_output()
+        self.assertEqual(len(output2.buffers), 4)
+        self.assertEqual(len(output2.pte_data), 2)
+        self.assertEqual(len(output2.external_data), 2)
+        self.assertEqual(len(output2.external_data["file1"]), 1)
+        self.assertEqual(len(output2.external_data["file2"]), 1)
+
+        # Merge the output of store2 into store1.
+        store1.merge_named_data_store(output2)
+
+        # Check items in store2 are merged into store1.
+        output = store1.get_named_data_store_output()
+        # key1, data1 exist in both store1 and store2, so we only have one copy of it.
+        self.assertEqual(len(output.buffers), 5)
+        self.assertEqual(len(output.pte_data), 2)
+        self.assertEqual(len(output.external_data), 2)
+        self.assertEqual(len(output.external_data["file1"]), 2)
+        self.assertEqual(len(output.external_data["file2"]), 1)
+
+    def test_merge_duplicate_error(self) -> None:
+        store1 = NamedDataStore()
+        store1.add_named_data("key1", b"data1", None, None)
+
+        # Check items in store1 before the merge.
+        output = store1.get_named_data_store_output()
+        self.assertEqual(len(output.buffers), 1)
+        self.assertEqual(len(output.pte_data), 1)
+
+        store2 = NamedDataStore()
+        store2.add_named_data("key1", b"data2", None, None)
+
+        # Check items in store2 before the merge.
+        output2 = store2.get_named_data_store_output()
+        self.assertEqual(len(output2.buffers), 1)
+        self.assertEqual(len(output2.pte_data), 1)
+
+        # Merge store2 into store1 raises error as key1 is already in store1
+        # with different data.
+        self.assertRaises(ValueError, store1.merge_named_data_store, output2)