27 changes: 27 additions & 0 deletions exir/_serialize/_named_data_store.py
@@ -181,3 +181,30 @@ def get_named_data_store_output(self) -> NamedDataStoreOutput:
        # Clean up empty maps inside self.external_data
        self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
        return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)

    def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
        """
        Merge the contents of another NamedDataStore into this one.

        Args:
            other (NamedDataStoreOutput): the output of the other NamedDataStore
                to merge into this one.

        Raises:
            ValueError: if a key exists in both stores but maps to different data.
        """
        # Merge the pte_data.
        for key, buffer_idx in other.pte_data.items():
            self.add_named_data(
                key,
                other.buffers[buffer_idx].buffer,
                other.buffers[buffer_idx].alignment,
            )

        # Merge the external_data.
        for filename, key_to_buffer_idx in other.external_data.items():
            for key, buffer_idx in key_to_buffer_idx.items():
                self.add_named_data(
                    key,
                    other.buffers[buffer_idx].buffer,
                    other.buffers[buffer_idx].alignment,
                    external_tag=filename,
                )
59 changes: 59 additions & 0 deletions exir/_serialize/test/test_named_data_store.py
@@ -83,3 +83,62 @@ def test_add_duplicate_key_fail(self) -> None:
        self.assertEqual(len(output.pte_data), 1)
        self.assertEqual(output.pte_data["key"], 0)
        self.assertEqual(len(output.external_data), 0)

    def test_merge(self) -> None:
        store1 = NamedDataStore()
        store1.add_named_data("key1", b"data1", None, None)
        store1.add_named_data("key2", b"data2", 16, "file1")

        # Check the items in store1.
        output = store1.get_named_data_store_output()
        self.assertEqual(len(output.buffers), 2)
        self.assertEqual(len(output.pte_data), 1)
        self.assertEqual(len(output.external_data), 1)
        self.assertEqual(len(output.external_data["file1"]), 1)

        store2 = NamedDataStore()
        store2.add_named_data("key1", b"data1", None, None)
        store2.add_named_data("key3", b"data3", None, None)
        store2.add_named_data("key4", b"data4", 16, "file1")
        store2.add_named_data("key5", b"data5", 16, "file2")

        # Check the items in store2.
        output2 = store2.get_named_data_store_output()
        self.assertEqual(len(output2.buffers), 4)
        self.assertEqual(len(output2.pte_data), 2)
        self.assertEqual(len(output2.external_data), 2)
        self.assertEqual(len(output2.external_data["file1"]), 1)
        self.assertEqual(len(output2.external_data["file2"]), 1)

        # Merge store2 into store1.
        store1.merge_named_data_store(output2)

        # Check that the items in store2 are merged into store1.
        output = store1.get_named_data_store_output()
        # key1 -> data1 exists in both store1 and store2, so only one copy is kept.
        self.assertEqual(len(output.buffers), 5)
        self.assertEqual(len(output.pte_data), 2)
        self.assertEqual(len(output.external_data), 2)
        self.assertEqual(len(output.external_data["file1"]), 2)
        self.assertEqual(len(output.external_data["file2"]), 1)

    def test_merge_duplicate_error(self) -> None:
        store1 = NamedDataStore()
        store1.add_named_data("key1", b"data1", None, None)

        # Check the items in store1.
        output = store1.get_named_data_store_output()
        self.assertEqual(len(output.buffers), 1)
        self.assertEqual(len(output.pte_data), 1)

        store2 = NamedDataStore()
        store2.add_named_data("key1", b"data2", None, None)

        # Check the items in store2.
        output2 = store2.get_named_data_store_output()
        self.assertEqual(len(output2.buffers), 1)
        self.assertEqual(len(output2.pte_data), 1)

        # Merging store2 into store1 raises an error because key1 already exists
        # in store1 with different data.
        self.assertRaises(ValueError, store1.merge_named_data_store, output2)