|  | 
| 19 | 19 | 
 | 
| 20 | 20 | from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile | 
| 21 | 21 | from executorch.exir._serialize._program import _insert_flatbuffer_header | 
| 22 |  | -from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer | 
|  | 22 | +from executorch.exir._serialize.data_serializer import ( | 
|  | 23 | +    DataEntry, | 
|  | 24 | +    DataPayload, | 
|  | 25 | +    DataSerializer, | 
|  | 26 | +) | 
| 23 | 27 | 
 | 
| 24 | 28 | from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required | 
| 25 | 29 | 
 | 
|  | 
| 34 | 38 | # endian. | 
| 35 | 39 | _HEADER_BYTEORDER: Literal["little"] = "little" | 
| 36 | 40 | 
 | 
|  | 41 | +# Current version. Keep in sync with c++ version number in serialize. | 
|  | 42 | +_FLAT_TENSOR_VERSION: int = 0 | 
|  | 43 | + | 
| 37 | 44 | 
 | 
| 38 | 45 | def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord: | 
| 39 | 46 |     """Serializes a FlatTensor to a flatbuffer and returns the serialized data.""" | 
| @@ -320,7 +327,7 @@ def serialize( | 
| 320 | 327 |         # Create FlatTensor, which describes of the contents of the file and | 
| 321 | 328 |         # points to all the data segments. It will be serialized to flatbuffer. | 
| 322 | 329 |         flat_tensor = FlatTensor( | 
| 323 |  | -            version=0,  # Keep in sync with c++ version number in serialize.h | 
|  | 330 | +            version=_FLAT_TENSOR_VERSION, | 
| 324 | 331 |             segments=data_segments, | 
| 325 | 332 |             named_data=named_data, | 
| 326 | 333 |         ) | 
| @@ -383,4 +390,49 @@ def deserialize(self, blob: Cord) -> DataPayload: | 
| 383 | 390 |         """ | 
| 384 | 391 |         Deserializes a flat_tensor blob into a list of tensor metadata and tensors. | 
| 385 | 392 |         """ | 
| 386 |  | -        raise NotImplementedError("deserialize_data") | 
|  | 393 | + | 
|  | 394 | +        data = bytes(blob) | 
|  | 395 | + | 
|  | 396 | +        # Read header. Verify that it's valid. | 
|  | 397 | +        header = FlatTensorHeader.from_bytes(data[8:]) | 
|  | 398 | +        if not header.is_valid(): | 
|  | 399 | +            raise RuntimeError( | 
|  | 400 | +                "Flat tensor header is invalid. File is likely incorrect format or corrupt." | 
|  | 401 | +            ) | 
|  | 402 | + | 
|  | 403 | +        # Deserialize the flat tensor data, which contains the data offsets and tensor metadata. | 
|  | 404 | +        flat_tensor_bytes = data[0 : header.flatbuffer_offset + header.flatbuffer_size] | 
|  | 405 | +        flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes) | 
|  | 406 | + | 
|  | 407 | +        # Verify that this is a supported version. | 
|  | 408 | +        if flat_tensor.version != _FLAT_TENSOR_VERSION: | 
|  | 409 | +            raise NotImplementedError( | 
|  | 410 | +                f"Flat tensor files reports unsupported version {flat_tensor.version}. Expected {_FLAT_TENSOR_VERSION}." | 
|  | 411 | +            ) | 
|  | 412 | + | 
|  | 413 | +        # Extract the buffers. | 
|  | 414 | +        buffers = [ | 
|  | 415 | +            data[ | 
|  | 416 | +                header.segment_base_offset | 
|  | 417 | +                + segment.offset : header.segment_base_offset | 
|  | 418 | +                + segment.offset | 
|  | 419 | +                + segment.size | 
|  | 420 | +            ] | 
|  | 421 | +            for segment in flat_tensor.segments | 
|  | 422 | +        ] | 
|  | 423 | + | 
|  | 424 | +        payload = DataPayload( | 
|  | 425 | +            buffers=buffers, | 
|  | 426 | +            named_data={}, | 
|  | 427 | +        ) | 
|  | 428 | + | 
|  | 429 | +        # Read the named data entries. | 
|  | 430 | +        for named_data in flat_tensor.named_data: | 
|  | 431 | +            entry = DataEntry( | 
|  | 432 | +                buffer_index=named_data.segment_index, | 
|  | 433 | +                alignment=1, | 
|  | 434 | +                tensor_layout=named_data.tensor_layout, | 
|  | 435 | +            ) | 
|  | 436 | +            payload.named_data[named_data.key] = entry | 
|  | 437 | + | 
|  | 438 | +        return payload | 
0 commit comments