Skip to content

Commit 1422bfe

Browse files
committed
Implement python deserialize for flat_tensor
1 parent 5960a4b commit 1422bfe

File tree

2 files changed

+79
-3
lines changed

2 files changed

+79
-3
lines changed

extension/flat_tensor/serialize/serialize.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2121
from executorch.exir._serialize._program import _insert_flatbuffer_header
22-
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
22+
from executorch.exir._serialize.data_serializer import DataEntry, DataPayload, DataSerializer
2323

2424
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
2525

@@ -34,6 +34,9 @@
3434
# endian.
3535
_HEADER_BYTEORDER: Literal["little"] = "little"
3636

37+
# Current version. Keep in sync with the C++ version number in serialize.h.
38+
_FLAT_TENSOR_VERSION: int = 0
39+
3740

3841
def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
3942
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -320,7 +323,7 @@ def serialize(
320323
# Create FlatTensor, which describes the contents of the file and
321324
# points to all the data segments. It will be serialized to flatbuffer.
322325
flat_tensor = FlatTensor(
323-
version=0, # Keep in sync with c++ version number in serialize.h
326+
version=_FLAT_TENSOR_VERSION,
324327
segments=data_segments,
325328
named_data=named_data,
326329
)
def deserialize(self, blob: Cord) -> DataPayload:
    """
    Deserializes a flat_tensor blob into a DataPayload.

    Args:
        blob: The serialized flat_tensor file contents.

    Returns:
        A DataPayload whose `buffers` holds one bytes object per segment
        listed in the file (in segment order) and whose `named_data` maps
        each key to a DataEntry pointing at its buffer by index. Every
        DataEntry is created with alignment 1; this method does not read
        an alignment value from the file.

    Raises:
        RuntimeError: If the header is invalid, or if a segment extends
            past the end of the blob.
        NotImplementedError: If the file reports a version other than
            _FLAT_TENSOR_VERSION.
    """
    data = bytes(blob)

    # Read header. Verify that it's valid. The header is parsed starting
    # 8 bytes into the blob.
    # NOTE(review): presumably those 8 bytes are the flatbuffer root offset
    # and file identifier -- confirm against _insert_flatbuffer_header.
    header = FlatTensorHeader.from_bytes(data[8:])
    if not header.is_valid():
        raise RuntimeError(
            "Flat tensor header is invalid. File is likely incorrect format or corrupt."
        )

    # Deserialize the flat tensor data, which contains the data offsets and
    # tensor metadata.
    flat_tensor_end = header.flatbuffer_offset + header.flatbuffer_size
    flat_tensor = _deserialize_to_flat_tensor(data[0:flat_tensor_end])

    # Verify that this is a supported version.
    if flat_tensor.version != _FLAT_TENSOR_VERSION:
        raise NotImplementedError(
            f"Flat tensor file reports unsupported version {flat_tensor.version}. "
            f"Expected {_FLAT_TENSOR_VERSION}."
        )

    # Extract the buffers. Segment offsets are relative to the segment base
    # offset recorded in the header.
    buffers = []
    for index, segment in enumerate(flat_tensor.segments):
        start = header.segment_base_offset + segment.offset
        end = start + segment.size
        # Slicing past the end of `data` would silently return a short
        # buffer, so reject out-of-range segments explicitly.
        if end > len(data):
            raise RuntimeError(
                f"Segment {index} spans bytes [{start}, {end}) but the blob is only "
                f"{len(data)} bytes. File is likely truncated or corrupt."
            )
        buffers.append(data[start:end])

    # Read the named data entries. Alignment is fixed at 1 because no
    # alignment value is read back from the file here.
    named_data = {
        record.key: DataEntry(
            buffer_index=record.segment_index,
            alignment=1,
            tensor_layout=record.tensor_layout,
        )
        for record in flat_tensor.named_data
    }

    return DataPayload(
        buffers=buffers,
        named_data=named_data,
    )

extension/flat_tensor/test/test_serialize.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88

9+
import dataclasses
910
import math
1011
import unittest
1112

@@ -17,6 +18,7 @@
1718
DataSerializer,
1819
)
1920

21+
from executorch.exir._serialize._cord import Cord
2022
from executorch.exir._serialize.padding import aligned_size
2123

2224
from executorch.exir.schema import ScalarType
@@ -223,3 +225,32 @@ def test_serialize(self) -> None:
223225
)
224226

225227
self.assertEqual(segments[2].offset + segments[2].size, len(segment_data))
228+
229+
def test_round_trip(self) -> None:
    """Serialize TEST_DATA_PAYLOAD, deserialize the result, and compare.

    Buffers are compared index by index. Named data entries are compared
    field by field, skipping `alignment`, which is not serialized and so
    cannot be expected to survive the round trip.
    """
    config = FlatTensorConfig()
    serializer: DataSerializer = FlatTensorSerializer(config)

    # Round trip the data.
    serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
    deserialized_payload = serializer.deserialize(Cord(serialized_data))

    # Validate the deserialized payload. Since alignment isn't serialized,
    # we need to do this somewhat manually.
    # NOTE(review): only the buffers present in the deserialized payload are
    # checked; the buffer counts themselves are not compared -- consider
    # asserting equal lengths if serialization is expected to keep them 1:1.
    for i, buffer in enumerate(deserialized_payload.buffers):
        self.assertEqual(
            TEST_DATA_PAYLOAD.buffers[i],
            buffer,
            f"Buffer at index {i} does not match.",
        )

    self.assertEqual(
        TEST_DATA_PAYLOAD.named_data.keys(),
        deserialized_payload.named_data.keys(),
    )

    skip_fields = {"alignment"}  # Fields to ignore in comparison.
    for key, reference in TEST_DATA_PAYLOAD.named_data.items():
        actual = deserialized_payload.named_data[key]

        for field in dataclasses.fields(reference):
            if field.name in skip_fields:
                continue
            self.assertEqual(
                getattr(reference, field.name),
                getattr(actual, field.name),
                f"Named data record {key}.{field.name} does not match.",
            )

0 commit comments

Comments
 (0)