77# pyre-strict
88
99import json
10+ import math
1011import os
1112import tempfile
1213from dataclasses import dataclass
1920from executorch .exir ._serialize ._flatbuffer import _flatc_compile , _flatc_decompile
2021from executorch .exir ._serialize ._program import _insert_flatbuffer_header
2122from executorch .exir ._serialize .data_serializer import (
23+ DataEntry ,
2224 DataPayload ,
2325 DataSerializer ,
2426 TensorEntry ,
2931from executorch .extension .flat_tensor .serialize .flat_tensor_schema import (
3032 DataSegment ,
3133 FlatTensor ,
34+ NamedData ,
3235 TensorMetadata ,
3336)
3437
@@ -202,6 +205,24 @@ def to_bytes(self) -> bytes:
202205 return data
203206
204207
@dataclass
class AlignedData:
    """
    Pairs a blob of data to serialize with the alignment recorded for it.

    Attributes:
        data: The payload to serialize, held as a cord.
        alignment: The alignment required for the payload. A missing (None)
            or zero alignment is stored as 1.
    """

    data: Cord
    alignment: int

    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        # None and 0 are both falsy; either one falls back to an alignment of 1.
        self.alignment = alignment if alignment else 1
        self.data = data
224+
225+
205226def _get_extended_header (flat_tensor_data : bytes ) -> Optional [FlatTensorHeader ]:
206227 """Returns the extended header of the flat_tensor data, if present and valid."""
207228 try :
@@ -216,7 +237,7 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
216237def _extract_tensors (
217238 fqn_to_tensor : Dict [str , TensorEntry ],
218239 buffers : Sequence [bytes ],
219- segments : List [Cord ],
240+ segments : List [AlignedData ],
220241 tensor_alignment : int ,
221242) -> List [TensorMetadata ]:
222243 """Places tensors into a single segment, aligned to tensor_alignment within
@@ -265,10 +286,43 @@ def _extract_tensors(
265286 offset = offset ,
266287 )
267288 )
268- segments .append (tensor_data )
289+ segments .append (AlignedData ( tensor_data ) )
269290 return tensors
270291
271292
def _extract_named_data(
    key_to_data: Dict[str, DataEntry],
    buffers: Sequence[bytes],
    segments: List[AlignedData],
) -> List[NamedData]:
    """Places named data into segments and records the alignment for each.

    Every distinct buffer index referenced by `key_to_data` gets exactly one
    new segment; keys that refer to the same buffer share that segment. When
    several entries share a buffer, the segment's alignment becomes the least
    common multiple of their alignments, so each entry's requirement holds.

    Args:
        key_to_data: A map from keys to opaque data entries.
        buffers: A sequence of buffers holding opaque blob data.
        segments: A list of segments to append data to. Modified in-place.

    Returns:
        A list of NamedData, one per key in iteration order, each naming the
        index in `segments` of the segment that holds the key's blob data.
    """

    # Map from buffer_idx to segment_idx.
    segment_index_map: Dict[int, int] = {}

    named_data: List[NamedData] = []
    for key, data_entry in key_to_data.items():
        buffer_idx = data_entry.buffer_index
        segment_index = segment_index_map.get(buffer_idx)
        if segment_index is None:
            # First time this buffer is seen: give it a segment of its own.
            segment_index = len(segments)
            segment_index_map[buffer_idx] = segment_index
            segments.append(
                AlignedData(Cord(buffers[buffer_idx]), data_entry.alignment)
            )
        else:
            # The buffer is shared with an earlier key. Widen the segment's
            # alignment instead of silently dropping this entry's requirement.
            # `or 1` mirrors AlignedData's treatment of a missing/zero alignment.
            shared_segment = segments[segment_index]
            shared_segment.alignment = math.lcm(
                shared_segment.alignment, data_entry.alignment or 1
            )
        named_data.append(NamedData(key=key, segment_index=segment_index))
    return named_data
324+
325+
272326class FlatTensorSerializer (DataSerializer ):
273327 """A concrete implementation of the DataSerializer interface that
274328 serializes and deserializes data to/from the FlatTensor format.
@@ -289,13 +343,14 @@ def serialize(
289343 ) -> Cord :
290344 """Serializes a list of tensors and named data into a blob."""
291345
292- segments : List [Cord ] = []
346+ segments : List [AlignedData ] = []
293347 tensors = _extract_tensors (
294348 data .fqn_to_tensor ,
295349 data .buffers ,
296350 segments ,
297351 self .config .tensor_alignment ,
298352 )
353+ named_data = _extract_named_data (data .key_to_data , data .buffers , segments )
299354
300355 data_segments : List [DataSegment ] = []
301356 segment_data = Cord ()
@@ -305,19 +360,18 @@ def serialize(
305360 if data_segments
306361 else 0
307362 )
363+ alignment = math .lcm (self .config .segment_alignment , segment .alignment )
308364 data_segments .append (
309365 DataSegment (
310- offset = aligned_size (prev_end , self . config . segment_alignment ),
311- size = len (segment ),
366+ offset = aligned_size (prev_end , alignment ),
367+ size = len (segment . data ),
312368 )
313369 )
314370 # Pad segment_data to segment alignment.
315- segment_pad_length = padding_required (
316- len (segment_data ), self .config .segment_alignment
317- )
371+ segment_pad_length = padding_required (len (segment_data ), alignment )
318372 if segment_pad_length > 0 :
319373 segment_data .append (b"\x00 " * segment_pad_length )
320- segment_data .append (segment )
374+ segment_data .append (segment . data )
321375
322376 # Create FlatTensor, which describes the contents of the file and
323377 # points to all the data segments. It will be serialized to flatbuffer.
@@ -326,7 +380,7 @@ def serialize(
326380 tensor_alignment = self .config .tensor_alignment ,
327381 tensors = tensors ,
328382 segments = data_segments ,
329- named_data = [] ,
383+ named_data = named_data ,
330384 )
331385
332386 flatbuffer_payload = _serialize_to_flatbuffer (flat_tensor )
0 commit comments