88
99import  copy 
1010import  json 
11+ import  math 
1112import  re 
1213
1314from  dataclasses  import  dataclass 
14- from  typing  import  ClassVar , List , Literal , Optional , Tuple 
15+ from  typing  import  ClassVar , Dict ,  List , Literal , Optional , Tuple 
1516
1617from  executorch .exir ._serialize ._cord  import  Cord 
1718from  executorch .exir ._serialize ._dataclass  import  _DataclassEncoder , _json_to_dataclass 
2021    _program_flatbuffer_to_json ,
2122    _program_json_to_flatbuffer ,
2223)
24+ from  executorch .exir ._serialize ._named_data_store  import  (
25+     BufferEntry ,
26+     NamedDataStoreOutput ,
27+ )
2328
2429from  executorch .exir ._serialize .padding  import  aligned_size , pad_to , padding_required 
2530
2934    Buffer ,
3035    DataLocation ,
3136    DataSegment ,
37+     NamedData ,
3238    Program ,
3339    SubsegmentOffsets ,
3440)
4147_HEADER_BYTEORDER : Literal ["little" ] =  "little" 
4248
4349
@dataclass
class AlignedData:
    """
    A segment payload paired with the alignment it should be serialized at.

    Attributes:
        data: The payload to serialize, held as a cord.
        alignment: The alignment requested for the payload. A missing (None)
            or zero value is stored as 1.
    """

    data: Cord
    alignment: int

    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        # The initializer is written by hand (the dataclass decorator leaves an
        # explicitly defined __init__ in place) so that callers may omit the
        # alignment or pass None, while the stored attribute is always an int.
        self.data = data
        # Any falsy alignment (None or 0) collapses to 1.
        self.alignment = alignment if alignment else 1
66+ 
67+ 
def _program_to_json(program: Program) -> str:
    """Serializes the given Program to its JSON string form.

    Encoding of the Program object is delegated to _DataclassEncoder.
    """
    encoded: str = json.dumps(program, cls=_DataclassEncoder)
    return encoded
@@ -213,7 +237,7 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
213237
214238def  _extract_delegate_segments (
215239    program : Program ,
216-     segments : List [Cord ],
240+     segments : List [AlignedData ],
217241) ->  None :
218242    """Extracts the delegate segments inlined in the program into a list of buffers. 
219243        The program is modified in-place to remove the delegate data. 
@@ -253,7 +277,7 @@ def _extract_delegate_segments(
253277                segment_index  =  segment_index_map .get (inline .data )
254278                if  segment_index  is  None :
255279                    segment_index  =  len (segments )
256-                     segments .append (Cord (inline .data ))
280+                     segments .append (AlignedData ( Cord (inline .data ) ))
257281                    segment_index_map [inline .data ] =  segment_index 
258282                delegate .processed  =  BackendDelegateDataReference (
259283                    location = DataLocation .SEGMENT ,
@@ -316,6 +340,44 @@ def _extract_constant_segment(
316340    return  constant_segment_data , constant_segment_offsets 
317341
318342
def _extract_named_data(
    program: Program,
    segments: List[AlignedData],
    buffers: List[BufferEntry],
    name_to_buffer_idx: Dict[str, int],
) -> None:
    """Attaches named-data references to the program, in place.

    Each distinct buffer index referenced by `name_to_buffer_idx` is appended
    to `segments` once, the first time it is encountered. Every name then gets
    a NamedData entry pointing at the segment index of its buffer, so names
    that share a buffer index share a segment.

    Args:
        program: The program that receives the named data entries. Modified
            in-place.
        segments: The running list of extracted segments; new entries are
            appended. Modified in-place.
        buffers: The unique buffers, each supplying the bytes and the
            alignment used to build its segment. Not modified.
        name_to_buffer_idx: A map from the name of a blob to an index into
            `buffers`. Not modified.

    Raises:
        ValueError: If the program already contains named data.
    """
    existing = program.named_data
    if existing is not None and len(existing) > 0:
        raise ValueError("Program already has named data.")

    # Buffer index -> index of the segment created for that buffer.
    buffer_to_segment: Dict[int, int] = {}
    entries: List[NamedData] = []

    for key, buf_idx in name_to_buffer_idx.items():
        if buf_idx not in buffer_to_segment:
            # First reference to this buffer: it becomes a new segment, placed
            # at the current end of the segment list.
            entry = buffers[buf_idx]
            buffer_to_segment[buf_idx] = len(segments)
            segments.append(AlignedData(Cord(entry.buffer), entry.alignment))
        entries.append(
            NamedData(key=key, segment_index=buffer_to_segment[buf_idx])
        )

    program.named_data = entries
379+ 
380+ 
319381def  serialize_pte_binary (
320382    program : Program ,
321383    * ,
@@ -324,6 +386,7 @@ def serialize_pte_binary(
324386    segment_alignment : int  =  128 ,
325387    constant_tensor_alignment : Optional [int ] =  None ,
326388    delegate_alignment : Optional [int ] =  None ,
389+     named_data : Optional [NamedDataStoreOutput ] =  None ,
327390) ->  Cord :
328391    """Returns the runtime binary representation of the given Program. 
329392
@@ -343,6 +406,8 @@ def serialize_pte_binary(
343406        delegate_alignment: If provided, the minimum alignment of delegate data 
344407            in the program. Must be a power of 2. If not provided, uses the 
345408            value in the schema file. 
409+         named_data: If provided, named blobs to be stored in segments 
410+             after the PTE file. 
346411    Returns: 
347412        The serialized form of the Program, ready for execution by the runtime. 
348413    """ 
@@ -355,8 +420,9 @@ def serialize_pte_binary(
355420    # copy, reusing the actual data blobs. 
356421    program  =  copy .deepcopy (program )
357422
358-     # Store extracted segment data; this may be constant data or delegate data. 
359-     segments : List [Cord ] =  []
423+     # Store extracted segment data, with any buffer-specific alignment. 
424+     # This may be constant data, delegate data or named data. 
425+     segments : List [AlignedData ] =  []
360426
361427    constant_segment_data , constant_segment_offsets  =  _extract_constant_segment (
362428        program .constant_buffer , tensor_alignment = constant_tensor_alignment 
@@ -374,7 +440,7 @@ def serialize_pte_binary(
374440        # Clear the constant buffer, as constant data will be stored in segments. 
375441        program .constant_buffer  =  []
376442        # Add to the aggregate segments cord. 
377-         segments .append (constant_segment_data )
443+         segments .append (AlignedData ( constant_segment_data ) )
378444
379445    if  mutable_data  is  not None :
380446        mutable_segment_data , mutable_segment_offsets  =  _extract_constant_segment (
@@ -389,31 +455,34 @@ def serialize_pte_binary(
389455                ),
390456            ]
391457            # Add to the aggregate segments cord. 
392-             segments .append (mutable_segment_data )
458+             segments .append (AlignedData ( mutable_segment_data ) )
393459
394460    if  extract_delegate_segments :
395461        _extract_delegate_segments (program , segments )
462+     if  named_data  is  not None :
463+         _extract_named_data (program , segments , named_data .buffers , named_data .pte_data )
396464
397465    # Append all segments into a single Cord, adding any necessary padding to ensure that 
398466    # each segment begins at the required alignment. 
399467    # Update program.segments with the offsets to each segment. 
400468    segments_data  =  Cord ()
401-     for  data  in  segments :
469+     for  segment  in  segments :
402470        prev_end  =  (
403471            (program .segments [- 1 ].offset  +  program .segments [- 1 ].size )
404472            if  program .segments 
405473            else  0 
406474        )
475+         alignment  =  math .lcm (segment_alignment , segment .alignment )
407476        program .segments .append (
408477            DataSegment (
409-                 offset = aligned_size (prev_end , segment_alignment ), size = len (data )
478+                 offset = aligned_size (prev_end , alignment ), size = len (segment . data )
410479            )
411480        )
412481        # Add to aggregate segments cord with padding. 
413-         padding_length  =  padding_required (len (segments_data ), segment_alignment )
482+         padding_length  =  padding_required (len (segments_data ), alignment )
414483        if  padding_length  >  0 :
415484            segments_data .append (b"\x00 "  *  padding_length )
416-         segments_data .append (data )
485+         segments_data .append (segment . data )
417486
418487    # Convert to a standard flatbuffer binary. 
419488    result : _FlatbufferResult  =  _program_json_to_flatbuffer (
0 commit comments