Skip to content

Commit 9d3c942

Browse files
committed
deserialize named data
1 parent 56659e4 commit 9d3c942

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

exir/_serialize/_program.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def from_bytes(data: bytes) -> "_ExtendedHeader":
197197

198198
magic = data[0:4]
199199
length = int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER)
200+
print(f"length: {length}")
200201
program_size = int.from_bytes(data[8:16], byteorder=_HEADER_BYTEORDER)
201202
segment_base_offset = int.from_bytes(data[16:24], byteorder=_HEADER_BYTEORDER)
202203
segment_data_size = (
@@ -226,6 +227,7 @@ def to_bytes(self) -> bytes:
226227
Note that this will ignore self.magic and self.length and will always
227228
write the proper magic/length.
228229
"""
230+
print(f"to_bytes: length: {self.EXPECTED_LENGTH}")
229231
data: bytes = (
230232
# Extended header magic. This lets consumers detect whether the
231233
# header was inserted or not. Always use the proper magic value
@@ -576,7 +578,7 @@ def serialize_pte_binary(
576578
return pte_data
577579

578580

579-
def _restore_segments(program: Program, segment_data: bytes) -> Program:
581+
def _restore_segments(program: Program, segment_data: bytes) -> Tuple[Program, NamedDataStoreOutput]:
580582
"""Moves segments from `segment_data` into `program`.
581583
582584
This should recreate the original Program that the segments were extracted
@@ -640,13 +642,28 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
640642
program.constant_buffer = buffers
641643
program.constant_segment.segment_index = 0
642644
program.constant_segment.offsets = []
645+
646+
buffers: List[BufferEntry] = []
647+
pte_data: Dict[str, int] = {}
648+
if program.named_data is not None and len(program.named_data) > 0:
649+
for i in range(len(program.named_data)):
650+
print(f"named data: {i} {program.named_data[i].key}")
651+
segment_index = program.named_data[i].segment_index
652+
if segment_index >= len(segments):
653+
raise ValueError(
654+
f"Named data {i} segment index {segment_index} >= num segments {len(segments)}"
655+
)
656+
buffers.append(BufferEntry(buffer=segments[segment_index], alignment=16))
657+
pte_data[program.named_data[i].key] = len(buffers) - 1
658+
program.named_data = []
643659

660+
named_data_store_output = NamedDataStoreOutput(buffers, pte_data, {})
644661
# Clear out the segments list since the original Program didn't have one.
645662
program.segments = []
646-
return program
663+
return program, named_data_store_output
647664

648665

649-
def deserialize_pte_binary(program_data: bytes) -> Program:
666+
def deserialize_pte_binary(program_data: bytes) -> Tuple[Program, NamedDataStoreOutput]:
650667
"""Returns a Program deserialized from the given runtime binary data."""
651668
program_size = len(program_data)
652669
segment_base_offset = 0
@@ -663,10 +680,11 @@ def deserialize_pte_binary(program_data: bytes) -> Program:
663680
_program_flatbuffer_to_json(program_data[:program_size])
664681
)
665682

683+
named_data_store_output = NamedDataStoreOutput([], {}, {})
666684
if segment_base_offset != 0:
667685
# Move segment data back into the Program.
668-
program = _restore_segments(
686+
program, named_data_store_output = _restore_segments(
669687
program=program, segment_data=program_data[segment_base_offset:]
670688
)
671689

672-
return program
690+
return program, named_data_store_output

0 commit comments

Comments
 (0)