77# pyre-strict
88
99import json
10+ import math
1011import os
1112import tempfile
1213from dataclasses import dataclass
1920from executorch .exir ._serialize ._flatbuffer import _flatc_compile , _flatc_decompile
2021from executorch .exir ._serialize ._program import _insert_flatbuffer_header
2122from executorch .exir ._serialize .data_serializer import (
23+ DataEntry ,
2224 DataPayload ,
2325 DataSerializer ,
2426 TensorEntry ,
2931from executorch .extension .flat_tensor .serialize .flat_tensor_schema import (
3032 DataSegment ,
3133 FlatTensor ,
34+ NamedData ,
3235 TensorMetadata ,
3336)
3437
@@ -202,6 +205,24 @@ def to_bytes(self) -> bytes:
202205 return data
203206
204207
@dataclass
class AlignedData:
    """
    Holds data that should be aligned, for serialization.

    Attributes:
        data: The data to serialize, as a cord.
        alignment: The alignment required for the data. Always stored as a
            positive integer: a missing (None) or zero request becomes 1.
    """

    data: "Cord"
    alignment: int

    def __init__(self, data: "Cord", alignment: Optional[int] = None) -> None:
        """Stores `data` together with its normalized alignment.

        Args:
            data: The data to serialize.
            alignment: The requested alignment. None or 0 means "no
                requirement" and is stored as 1.

        Raises:
            ValueError: If `alignment` is negative.
        """
        # A negative alignment has no meaning for offset padding; fail here
        # rather than letting it flow into the offset computations.
        if alignment is not None and alignment < 0:
            raise ValueError(f"alignment must not be negative, got {alignment}")
        self.data = data
        # Normalize so that consumers can always treat alignment as >= 1.
        self.alignment = alignment or 1
224+
225+
205226def _get_extended_header (flat_tensor_data : bytes ) -> Optional [FlatTensorHeader ]:
206227 """Returns the extended header of the flat_tensor data, if present and valid."""
207228 try :
@@ -216,7 +237,7 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
216237def _extract_tensors (
217238 fqn_to_tensor : Dict [str , TensorEntry ],
218239 buffers : Sequence [bytes ],
219- segments : List [Cord ],
240+ segments : List [AlignedData ],
220241 tensor_alignment : int ,
221242) -> List [TensorMetadata ]:
222243 """Places tensors into a single segment, aligned to tensor_alignment within
@@ -265,10 +286,43 @@ def _extract_tensors(
265286 offset = offset ,
266287 )
267288 )
268- segments .append (tensor_data )
289+ segments .append (AlignedData ( tensor_data ) )
269290 return tensors
270291
271292
def _extract_named_data(
    key_to_data: Dict[str, DataEntry],
    buffers: Sequence[bytes],
    segments: List[AlignedData],
) -> List[NamedData]:
    """Places named data into segments and records the alignment for each.

    Each distinct buffer referenced by `key_to_data` is appended to `segments`
    once; keys that reference the same buffer share that segment. When keys
    sharing a buffer request different alignments, the segment is given the
    least common multiple of the requests so that every request is satisfied.

    Args:
        key_to_data: A map from keys to opaque data entries.
        buffers: A sequence of buffers holding opaque blob data.
        segments: A list of segments to append data to. Modified in-place.

    Returns:
        A list of NamedData describing the offsets to the opaque blob data.
    """

    # Map from buffer_idx to segment_idx.
    segment_index_map: Dict[int, int] = {}

    named_data: List[NamedData] = []
    for key, data_entry in key_to_data.items():
        buffer_idx = data_entry.buffer_index
        segment_index = segment_index_map.get(buffer_idx)
        if segment_index is None:
            # First time this buffer is seen: give it a segment of its own.
            segment_index = len(segments)
            segment_index_map[buffer_idx] = segment_index
            segments.append(
                AlignedData(Cord(buffers[buffer_idx]), data_entry.alignment)
            )
        else:
            # The buffer already has a segment. Widen its alignment so that
            # this entry's requirement holds too; a missing or zero alignment
            # is treated as 1, i.e. no additional constraint.
            shared_segment = segments[segment_index]
            shared_segment.alignment = math.lcm(
                shared_segment.alignment, data_entry.alignment or 1
            )
        named_data.append(NamedData(key=key, segment_index=segment_index))
    return named_data
324+
325+
272326class FlatTensorSerializer (DataSerializer ):
273327 """A concrete implementation of the DataSerializer interface that
274328 serializes and deserializes data to/from the FlatTensor format.
@@ -289,35 +343,37 @@ def serialize(
289343 ) -> Cord :
290344 """Serializes a list of tensors and named data into a blob."""
291345
292- segments : List [Cord ] = []
346+ segments : List [AlignedData ] = []
293347 tensors = _extract_tensors (
294348 data .fqn_to_tensor ,
295349 data .buffers ,
296350 segments ,
297351 self .config .tensor_alignment ,
298352 )
353+ named_data = _extract_named_data (data .key_to_data , data .buffers , segments )
299354
300355 data_segments : List [DataSegment ] = []
301- segment_data = Cord ()
356+ aggregated_segment_data = Cord ()
302357 for segment in segments :
303358 prev_end = (
304359 (data_segments [- 1 ].offset + data_segments [- 1 ].size )
305360 if data_segments
306361 else 0
307362 )
363+ alignment = math .lcm (self .config .segment_alignment , segment .alignment )
308364 data_segments .append (
309365 DataSegment (
310- offset = aligned_size (prev_end , self . config . segment_alignment ),
311- size = len (segment ),
366+ offset = aligned_size (prev_end , alignment ),
367+ size = len (segment . data ),
312368 )
313369 )
314- # Pad segment_data to segment alignment.
370+ # Pad aggregated_segment_data to segment alignment.
315371 segment_pad_length = padding_required (
316- len (segment_data ), self . config . segment_alignment
372+ len (aggregated_segment_data ), alignment
317373 )
318374 if segment_pad_length > 0 :
319- segment_data .append (b"\x00 " * segment_pad_length )
320- segment_data .append (segment )
375+ aggregated_segment_data .append (b"\x00 " * segment_pad_length )
376+ aggregated_segment_data .append (segment . data )
321377
322378 # Create FlatTensor, which describes the contents of the file and
323379 # points to all the data segments. It will be serialized to flatbuffer.
@@ -326,7 +382,7 @@ def serialize(
326382 tensor_alignment = self .config .tensor_alignment ,
327383 tensors = tensors ,
328384 segments = data_segments ,
329- named_data = [] ,
385+ named_data = named_data ,
330386 )
331387
332388 flatbuffer_payload = _serialize_to_flatbuffer (flat_tensor )
@@ -351,7 +407,7 @@ def serialize(
351407 flatbuffer_offset = padded_header_length ,
352408 flatbuffer_size = len (flatbuffer_payload ),
353409 segment_base_offset = segment_base_offset ,
354- segment_data_size = len (segment_data ),
410+ segment_data_size = len (aggregated_segment_data ),
355411 ).to_bytes ()
356412
357413 # Pad header and payload to segment alignment.
@@ -371,15 +427,15 @@ def serialize(
371427 assert eh .flatbuffer_size == original_flatbuffer_payload_size
372428 assert eh .segment_base_offset == segment_base_offset
373429 assert eh .flatbuffer_offset == padded_header_length
374- assert eh .segment_data_size == len (segment_data )
430+ assert eh .segment_data_size == len (aggregated_segment_data )
375431
376432 del header_data
377433 del flatbuffer_payload
378434
379435 # Place everything into one segment.
380436 payload = Cord ()
381437 payload .append (injected_flatbuffer_data )
382- payload .append (segment_data )
438+ payload .append (aggregated_segment_data )
383439
384440 return payload
385441
0 commit comments