From 7a33ea8d0070dda088ca012d8aeda7e1be2c0c86 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Tue, 17 Jun 2025 21:33:52 -0700 Subject: [PATCH] Implement python deserialize for flat_tensor --- extension/flat_tensor/serialize/serialize.py | 58 +++++++++++++++++++- extension/flat_tensor/test/test_serialize.py | 40 +++++++++++++- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index 5b29d7ccacd..7f3332c4303 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -19,7 +19,11 @@ from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile from executorch.exir._serialize._program import _insert_flatbuffer_header -from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer +from executorch.exir._serialize.data_serializer import ( + DataEntry, + DataPayload, + DataSerializer, +) from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required @@ -34,6 +38,9 @@ # endian. _HEADER_BYTEORDER: Literal["little"] = "little" +# Current version. Keep in sync with c++ version number in serialize. +_FLAT_TENSOR_VERSION: int = 0 + def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord: """Serializes a FlatTensor to a flatbuffer and returns the serialized data.""" @@ -320,7 +327,7 @@ def serialize( # Create FlatTensor, which describes of the contents of the file and # points to all the data segments. It will be serialized to flatbuffer. flat_tensor = FlatTensor( - version=0, # Keep in sync with c++ version number in serialize.h + version=_FLAT_TENSOR_VERSION, segments=data_segments, named_data=named_data, ) @@ -383,4 +390,49 @@ def deserialize(self, blob: Cord) -> DataPayload: """ Deserializes a flat_tensor blob into a list of tensor metadata and tensors. """ - raise NotImplementedError("deserialize_data") + + data = bytes(blob) + + # Read header. Verify that it's valid. + header = FlatTensorHeader.from_bytes(data[8:]) + if not header.is_valid(): + raise RuntimeError( + "Flat tensor header is invalid. File is likely incorrect format or corrupt." + ) + + # Deserialize the flat tensor data, which contains the data offsets and tensor metadata. + flat_tensor_bytes = data[0 : header.flatbuffer_offset + header.flatbuffer_size] + flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes) + + # Verify that this is a supported version. + if flat_tensor.version != _FLAT_TENSOR_VERSION: + raise NotImplementedError( + f"Flat tensor files reports unsupported version {flat_tensor.version}. Expected {_FLAT_TENSOR_VERSION}." + ) + + # Extract the buffers. + buffers = [ + data[ + header.segment_base_offset + + segment.offset : header.segment_base_offset + + segment.offset + + segment.size + ] + for segment in flat_tensor.segments + ] + + payload = DataPayload( + buffers=buffers, + named_data={}, + ) + + # Read the named data entries. + for named_data in flat_tensor.named_data: + entry = DataEntry( + buffer_index=named_data.segment_index, + alignment=1, + tensor_layout=named_data.tensor_layout, + ) + payload.named_data[named_data.key] = entry + + return payload diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 80ee59ae974..13402e60a65 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -6,17 +6,19 @@ # pyre-unsafe +import dataclasses import math import unittest from typing import List, Optional +from executorch.exir._serialize._cord import Cord + from executorch.exir._serialize.data_serializer import ( DataEntry, DataPayload, DataSerializer, ) - from executorch.exir._serialize.padding import aligned_size from executorch.exir.schema import ScalarType @@ -223,3 +225,39 @@ def test_serialize(self) -> None: ) self.assertEqual(segments[2].offset + segments[2].size, len(segment_data)) + + def test_round_trip(self) -> None: + # Serialize and then deserialize the test payload. Make sure it's reconstructed + # properly. + config = FlatTensorConfig() + serializer: DataSerializer = FlatTensorSerializer(config) + + # Round trip the data. + serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD)) + deserialized_payload = serializer.deserialize(Cord(serialized_data)) + + # Validate the deserialized payload. Since alignment isn't serialized, we need to + # do this somewhat manually. + for i in range(len(deserialized_payload.buffers)): + self.assertEqual( + TEST_DATA_PAYLOAD.buffers[i], + deserialized_payload.buffers[i], + f"Buffer at index {i} does not match.", + ) + + self.assertEqual( + TEST_DATA_PAYLOAD.named_data.keys(), deserialized_payload.named_data.keys() + ) + + SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. + for key in TEST_DATA_PAYLOAD.named_data.keys(): + reference = TEST_DATA_PAYLOAD.named_data[key] + actual = deserialized_payload.named_data[key] + + for field in dataclasses.fields(reference): + if field.name not in SKIP_FIELDS: + self.assertEqual( + getattr(reference, field.name), + getattr(actual, field.name), + f"Named data record {key}.{field.name} does not match.", + )