88
99import copy
1010import json
11+ import math
1112import re
1213
1314from dataclasses import dataclass
14- from typing import ClassVar , List , Literal , Optional , Tuple
15+ from typing import ClassVar , Dict , List , Literal , Optional , Tuple
1516
1617from executorch .exir ._serialize ._cord import Cord
1718from executorch .exir ._serialize ._dataclass import _DataclassEncoder , _json_to_dataclass
2021 _program_flatbuffer_to_json ,
2122 _program_json_to_flatbuffer ,
2223)
24+ from executorch .exir ._serialize ._named_data_store import (
25+ BufferEntry ,
26+ NamedDataStoreOutput ,
27+ )
2328
2429from executorch .exir ._serialize .padding import aligned_size , pad_to , padding_required
2530
2934 Buffer ,
3035 DataLocation ,
3136 DataSegment ,
37+ NamedData ,
3238 Program ,
3339 SubsegmentOffsets ,
3440)
4147_HEADER_BYTEORDER : Literal ["little" ] = "little"
4248
4349
@dataclass
class AlignedData:
    """
    Pairs a payload to be serialized with the alignment it requires.

    Attributes:
        data: The payload to serialize, as a cord.
        alignment: The alignment required for the payload. A falsy value
            passed to the constructor (None or 0) is stored as 1.
    """

    data: Cord
    alignment: int

    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        # The explicit constructor lets callers omit the alignment entirely;
        # a missing (or zero) alignment is normalized to 1.
        effective_alignment = alignment if alignment else 1
        self.data = data
        self.alignment = effective_alignment
66+
67+
def _program_to_json(program: Program) -> str:
    """Serializes the given Program to its JSON string representation.

    Encoding of the program object is delegated to _DataclassEncoder.
    """
    program_json: str = json.dumps(program, cls=_DataclassEncoder)
    return program_json
@@ -213,7 +237,7 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
213237
214238def _extract_delegate_segments (
215239 program : Program ,
216- segments : List [Cord ],
240+ segments : List [AlignedData ],
217241) -> None :
218242 """Extracts the delegate segments inlined in the program into a list of buffers.
219243 The program is modified in-place to remove the delegate data.
@@ -253,7 +277,7 @@ def _extract_delegate_segments(
253277 segment_index = segment_index_map .get (inline .data )
254278 if segment_index is None :
255279 segment_index = len (segments )
256- segments .append (Cord (inline .data ))
280+ segments .append (AlignedData ( Cord (inline .data ) ))
257281 segment_index_map [inline .data ] = segment_index
258282 delegate .processed = BackendDelegateDataReference (
259283 location = DataLocation .SEGMENT ,
@@ -316,6 +340,44 @@ def _extract_constant_segment(
316340 return constant_segment_data , constant_segment_offsets
317341
318342
def _extract_named_data(
    program: Program,
    segments: List[AlignedData],
    buffers: List[BufferEntry],
    name_to_buffer_idx: Dict[str, int],
) -> None:
    """Appends named-data buffers to `segments` and references them from the program.

    Every distinct buffer index found in `name_to_buffer_idx` is appended to
    `segments` exactly once, so names that map to the same buffer share a
    single segment. `program.named_data` is then set to one NamedData entry
    per name, in the iteration order of `name_to_buffer_idx`.

    Args:
        program: The program to add named data references to. Modified in-place.
        segments: A list of buffers to append extracted segments to. Modified in-place.
        buffers: A list of unique buffers and the information required to
            serialize them. Not modified.
        name_to_buffer_idx: A map from the name of a blob to the index in buffers.
            Not modified.

    Raises:
        ValueError: If `program.named_data` is already a non-empty list.
    """
    if not (program.named_data is None or len(program.named_data) == 0):
        raise ValueError("Program already has named data.")

    # buffer index -> index of the segment holding that buffer.
    buffer_to_segment: Dict[int, int] = {}
    entries: List[NamedData] = []

    for key, buf_idx in name_to_buffer_idx.items():
        if buf_idx not in buffer_to_segment:
            # First reference to this buffer: it becomes the next segment.
            buffer_to_segment[buf_idx] = len(segments)
            entry = buffers[buf_idx]
            segments.append(AlignedData(Cord(entry.buffer), entry.alignment))
        entries.append(
            NamedData(key=key, segment_index=buffer_to_segment[buf_idx])
        )

    program.named_data = entries
379+
380+
319381def serialize_pte_binary (
320382 program : Program ,
321383 * ,
@@ -324,6 +386,7 @@ def serialize_pte_binary(
324386 segment_alignment : int = 128 ,
325387 constant_tensor_alignment : Optional [int ] = None ,
326388 delegate_alignment : Optional [int ] = None ,
389+ named_data : Optional [NamedDataStoreOutput ] = None ,
327390) -> Cord :
328391 """Returns the runtime binary representation of the given Program.
329392
@@ -343,6 +406,8 @@ def serialize_pte_binary(
343406 delegate_alignment: If provided, the minimum alignment of delegate data
344407 in the program. Must be a power of 2. If not provided, uses the
345408 value in the schema file.
409+ named_data: If provided, named blobs to be stored in segments
410+ after the PTE file.
346411 Returns:
347412 The serialized form of the Program, ready for execution by the runtime.
348413 """
@@ -355,8 +420,9 @@ def serialize_pte_binary(
355420 # copy, reusing the actual data blobs.
356421 program = copy .deepcopy (program )
357422
358- # Store extracted segment data; this may be constant data or delegate data.
359- segments : List [Cord ] = []
423+ # Store extracted segment data, with any buffer-specific alignment.
424+ # This may be constant data, delegate data or named data.
425+ segments : List [AlignedData ] = []
360426
361427 constant_segment_data , constant_segment_offsets = _extract_constant_segment (
362428 program .constant_buffer , tensor_alignment = constant_tensor_alignment
@@ -374,7 +440,7 @@ def serialize_pte_binary(
374440 # Clear the constant buffer, as constant data will be stored in segments.
375441 program .constant_buffer = []
376442 # Add to the aggregate segments cord.
377- segments .append (constant_segment_data )
443+ segments .append (AlignedData ( constant_segment_data ) )
378444
379445 if mutable_data is not None :
380446 mutable_segment_data , mutable_segment_offsets = _extract_constant_segment (
@@ -389,31 +455,37 @@ def serialize_pte_binary(
389455 ),
390456 ]
391457 # Add to the aggregate segments cord.
392- segments .append (mutable_segment_data )
458+ segments .append (AlignedData ( mutable_segment_data ) )
393459
394460 if extract_delegate_segments :
395461 _extract_delegate_segments (program , segments )
462+ if named_data is not None :
463+ _extract_named_data (program , segments , named_data .buffers , named_data .pte_data )
396464
397465 # Append all segments into a single Cord, adding any necessary padding to ensure that
398466 # each segment begins at the required alignment.
399467 # Update program.segments with the offsets to each segment.
400468 segments_data = Cord ()
401- for data in segments :
469+ prev_alignment = segment_alignment
470+ for segment in segments :
402471 prev_end = (
403472 (program .segments [- 1 ].offset + program .segments [- 1 ].size )
404473 if program .segments
405474 else 0
406475 )
407476 program .segments .append (
408477 DataSegment (
409- offset = aligned_size (prev_end , segment_alignment ), size = len (data )
478+ offset = aligned_size (prev_end , prev_alignment ), size = len (segment . data )
410479 )
411480 )
412481 # Add to aggregate segments cord with padding.
413- padding_length = padding_required (len (segments_data ), segment_alignment )
482+ padding_length = padding_required (len (segments_data ), prev_alignment )
414483 if padding_length > 0 :
415484 segments_data .append (b"\x00 " * padding_length )
416- segments_data .append (data )
485+ segments_data .append (segment .data )
486+ # Update alignment for next segment. Take the lcm of the segment
487+ # alignment and the segment.alignment.
488+ prev_alignment = math .lcm (segment_alignment , segment .alignment )
417489
418490 # Convert to a standard flatbuffer binary.
419491 result : _FlatbufferResult = _program_json_to_flatbuffer (
0 commit comments