|
19 | 19 |
|
20 | 20 | from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile |
21 | 21 | from executorch.exir._serialize._program import _insert_flatbuffer_header |
22 | | -from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer |
| 22 | +from executorch.exir._serialize.data_serializer import DataEntry, DataPayload, DataSerializer |
23 | 23 |
|
24 | 24 | from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required |
25 | 25 |
|
|
34 | 34 | # endian. |
35 | 35 | _HEADER_BYTEORDER: Literal["little"] = "little" |
36 | 36 |
|
| 37 | +# Current version. Keep in sync with the C++ version number in serialize.h.
| 38 | +_FLAT_TENSOR_VERSION: int = 0 |
| 39 | + |
37 | 40 |
|
38 | 41 | def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord: |
39 | 42 | """Serializes a FlatTensor to a flatbuffer and returns the serialized data.""" |
@@ -320,7 +323,7 @@ def serialize( |
320 | 323 |         # Create FlatTensor, which describes the contents of the file and
321 | 324 |         # points to all the data segments. It will be serialized to flatbuffer.
322 | 325 |         flat_tensor = FlatTensor(
323 | | -            version=0,  # Keep in sync with c++ version number in serialize.h
| 326 | +            version=_FLAT_TENSOR_VERSION,
324 | 327 |             segments=data_segments,
325 | 328 |             named_data=named_data,
326 | 329 |         )
@@ -383,4 +386,46 @@ def deserialize(self, blob: Cord) -> DataPayload: |
383 | 386 | """ |
384 | 387 | Deserializes a flat_tensor blob into a list of tensor metadata and tensors. |
385 | 388 | """ |
386 | | - raise NotImplementedError("deserialize_data") |
| 389 | +
| 390 | +        data = bytes(blob)
| 391 | +
| 392 | +        # Read the header, located after the first 8 bytes of the file, and verify that it is valid.
| 393 | +        header = FlatTensorHeader.from_bytes(data[8:])
| 394 | +        if not header.is_valid():
| 395 | +            raise RuntimeError("Flat tensor header is invalid. File is likely in the wrong format or corrupt.")
| 396 | +
| 397 | +        # Deserialize the flat tensor data, which contains the data offsets and tensor metadata.
| 398 | +        flat_tensor_bytes = data[
| 399 | +            0 : header.flatbuffer_offset + header.flatbuffer_size
| 400 | +        ]
| 401 | +        flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)
| 402 | +
| 403 | +        # Verify that this is a supported version.
| 404 | +        if flat_tensor.version != _FLAT_TENSOR_VERSION:
| 405 | +            raise NotImplementedError(f"Flat tensor file reports unsupported version {flat_tensor.version}. Expected {_FLAT_TENSOR_VERSION}.")
| 406 | +
| 407 | +        # Extract the raw bytes of each data segment. Segment offsets are
| 408 | +        # relative to the segment base offset given in the header.
| 409 | +        buffers = [
| 410 | +            data[
| 411 | +                header.segment_base_offset + segment.offset :
| 412 | +                header.segment_base_offset + segment.offset + segment.size
| 413 | +            ]
| 414 | +            for segment in flat_tensor.segments
| 415 | +        ]
| 416 | +
| 417 | +        payload = DataPayload(
| 418 | +            buffers=buffers,
| 419 | +            named_data={},
| 420 | +        )
| 421 | +
| 422 | +        # Read the named data entries. Each entry references one segment by index.
| 423 | +        for named_data in flat_tensor.named_data:
| 424 | +            entry = DataEntry(
| 425 | +                buffer_index=named_data.segment_index,
| 426 | +                alignment=1,
| 427 | +                tensor_layout=named_data.tensor_layout,
| 428 | +            )
| 429 | +            payload.named_data[named_data.key] = entry
| 430 | +
| 431 | +        return payload
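
As a quick reference for reviewers, here is a minimal sketch of how the `DataPayload` returned by the new `deserialize()` could be consumed. The helper name `summarize_payload` is hypothetical and not part of this change; it relies only on fields visible in this diff (`DataPayload.buffers`, `DataPayload.named_data`, and `DataEntry`'s `buffer_index`, `alignment`, and `tensor_layout`).

```python
# Hypothetical helper (not part of this PR): prints a summary of a DataPayload
# produced by the deserialize() method above.
from executorch.exir._serialize.data_serializer import DataPayload


def summarize_payload(payload: DataPayload) -> None:
    """Print one line per named data entry and the size of the buffer it points to."""
    for key, entry in payload.named_data.items():
        # Each named entry points at one of the extracted segment buffers.
        buffer = payload.buffers[entry.buffer_index]
        print(
            f"{key}: {len(buffer)} bytes "
            f"(buffer_index={entry.buffer_index}, alignment={entry.alignment}, "
            f"tensor_layout={entry.tensor_layout})"
        )
```

Each named entry resolves to one of the extracted segment buffers via `buffer_index`, which is why `deserialize()` can hand the segments back as plain `bytes` and report an alignment of 1.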