Skip to content

Commit 5edbffd

Browse files
committed
Deserialize with named data store
As titled
1 parent f24351a commit 5edbffd

File tree

3 files changed

+80
-3
lines changed

3 files changed

+80
-3
lines changed

exir/_serialize/_dataclass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _get_class_from_union(json_dict: Dict[str, Any], key: str, cls: Any) -> Any:
5757

5858

5959
# pyre-ignore
60-
def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any:
60+
def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any: # noqa: C901
6161
"""Initializes a dataclass given a dictionary loaded from a json,
6262
`json_dict`, and the expected class, `cls`, by iterating through the fields
6363
of the class and retrieving the data for each. If there is a field that is
@@ -139,7 +139,10 @@ class Example
139139
# If T is an enum then lookup the value in the enum otherwise try to
140140
# cast value to whatever type is required
141141
if isinstance(T, enum.EnumMeta):
142-
data[key] = T[value]
142+
if isinstance(value, str):
143+
data[key] = T[value]
144+
else:
145+
data[key] = T(value)
143146
else:
144147
data[key] = T(value)
145148
return cls(**data)

exir/_serialize/_program.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def serialize_pte_binary(
576576
return pte_data
577577

578578

579-
def _restore_segments(program: Program, segment_data: bytes) -> Program:
579+
def _restore_segments(program: Program, segment_data: bytes) -> Program: # noqa: C901
580580
"""Moves segments from `segment_data` into `program`.
581581
582582
This should recreate the original Program that the segments were extracted
@@ -641,6 +641,48 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
641641
program.constant_segment.segment_index = 0
642642
program.constant_segment.offsets = []
643643

644+
# Reconstruct named data blobs from segment data when present.
645+
if program.named_data:
646+
segment_to_buffer_index: Dict[int, int] = {}
647+
named_buffers: List[BufferEntry] = []
648+
key_to_buffer_index: Dict[str, int] = {}
649+
650+
for entry in program.named_data:
651+
segment_index = entry.segment_index
652+
if segment_index >= len(segments):
653+
raise ValueError(
654+
"Named data segment index "
655+
f"{segment_index} >= num segments {len(segments)}"
656+
)
657+
658+
buffer_index = segment_to_buffer_index.get(segment_index)
659+
if buffer_index is None:
660+
buffer_index = len(named_buffers)
661+
segment_to_buffer_index[segment_index] = buffer_index
662+
named_buffers.append(
663+
BufferEntry(buffer=segments[segment_index], alignment=1)
664+
)
665+
666+
key_to_buffer_index[entry.key] = buffer_index
667+
668+
named_data_store = NamedDataStoreOutput(
669+
buffers=named_buffers,
670+
pte_data=key_to_buffer_index,
671+
external_data={},
672+
)
673+
# Keep a convenient mapping from key to raw bytes for callers that only
674+
# need to read the blobs.
675+
setattr( # noqa: B010
676+
program,
677+
"named_data_blobs",
678+
{
679+
key: named_data_store.buffers[idx].buffer
680+
for key, idx in named_data_store.pte_data.items()
681+
},
682+
)
683+
setattr(program, "named_data_store", named_data_store) # noqa: B010
684+
program.named_data = []
685+
644686
# Clear out the segments list since the original Program didn't have one.
645687
program.segments = []
646688
return program

exir/_serialize/test/test_program.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,23 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
893893
self.assertEqual(program2.execution_plan, program.execution_plan)
894894
# Number of constant tensors should be the same.
895895
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
896+
# Named data should be restored in a named data store and removed from the Program.
897+
self.assertEqual(program2.named_data, [])
898+
self.assertTrue(hasattr(program2, "named_data_store"))
899+
named_store = program2.named_data_store
900+
self.assertEqual(named_store.pte_data, pte_named_data)
901+
# Buffers in the restored store should match the original serialized blobs.
902+
restored = {
903+
key: named_store.buffers[idx].buffer
904+
for key, idx in named_store.pte_data.items()
905+
}
906+
original = {
907+
key: named_data_buffers[buf_idx].buffer
908+
for key, buf_idx in pte_named_data.items()
909+
}
910+
self.assertEqual(restored, original)
911+
self.assertTrue(hasattr(program2, "named_data_blobs"))
912+
self.assertEqual(program2.named_data_blobs, restored)
896913

897914
def test_named_data_segments(self) -> None:
898915
# Set segment alignment to 12 to test the padding.
@@ -997,6 +1014,21 @@ def test_named_data_segments(self) -> None:
9971014
buffers[2].buffer,
9981015
)
9991016

1017+
program2 = deserialize_pte_binary(pte_data)
1018+
self.assertEqual(program2.named_data, [])
1019+
self.assertTrue(hasattr(program2, "named_data_store"))
1020+
store = program2.named_data_store
1021+
self.assertEqual(store.pte_data, pte_named_data)
1022+
restored_named_data = {
1023+
key: store.buffers[idx].buffer for key, idx in store.pte_data.items()
1024+
}
1025+
self.assertEqual(
1026+
restored_named_data,
1027+
{key: buffers[idx].buffer for key, idx in pte_named_data.items()},
1028+
)
1029+
self.assertTrue(hasattr(program2, "named_data_blobs"))
1030+
self.assertEqual(program2.named_data_blobs, restored_named_data)
1031+
10001032

10011033
# Common data for extended header tests. The two example values should produce
10021034
# the example data.

0 commit comments

Comments
 (0)