1111import  re 
1212
1313from  dataclasses  import  dataclass 
14- from  typing  import  ClassVar , List , Literal ,  Optional , Tuple 
14+ from  typing  import  ClassVar , List , Optional , Tuple 
1515
1616from  executorch .exir ._serialize ._cord  import  Cord 
1717from  executorch .exir ._serialize ._dataclass  import  _DataclassEncoder , _json_to_dataclass 
2121    _program_json_to_flatbuffer ,
2222)
2323
24+ from  executorch .exir ._serialize .utils  import  (
25+     aligned_size ,
26+     HEADER_BYTEORDER ,
27+     pad_to ,
28+     padding_required ,
29+ )
30+ 
2431from  executorch .exir .schema  import  (
2532    BackendDelegateDataReference ,
2633    BackendDelegateInlineData ,
3340from  executorch .exir .tensor  import  ALIGNMENT 
3441
3542
36- # Byte order of numbers written to program headers. Always little-endian 
37- # regardless of the host system, since all commonly-used modern CPUs are little 
38- # endian. 
39- _HEADER_BYTEORDER : Literal ["little" ] =  "little" 
40- 
41- 
4243def  _program_to_json (program : Program ) ->  str :
4344    """Returns the JSON representation of the given Program.""" 
4445    return  json .dumps (program , cls = _DataclassEncoder )
@@ -50,19 +51,6 @@ def _json_to_program(program_json: bytes) -> Program:
5051    return  _json_to_dataclass (json .loads (program_json ), cls = Program )
5152
5253
53- def  _padding_required (offset : int , alignment : int ) ->  int :
54-     """Returns the padding required to align `offset` to `alignment`.""" 
55-     remainder : int  =  offset  %  alignment 
56-     if  remainder  !=  0 :
57-         return  alignment  -  remainder 
58-     return  0 
59- 
60- 
61- def  _aligned_size (input_size : int , alignment : int ) ->  int :
62-     """Returns input_size padded up to the next whole multiple of alignment.""" 
63-     return  input_size  +  _padding_required (input_size , alignment )
64- 
65- 
6654def  _insert_flatbuffer_header (
6755    flatbuffer_data : bytes , magic_regex : str , header_data : bytes 
6856) ->  bytes :
@@ -102,11 +90,11 @@ def _insert_flatbuffer_header(
10290        return  flatbuffer_data 
10391
10492    # We will need to adjust the root object offset after inserting the header. 
105-     root_offset  =  int .from_bytes (flatbuffer_data [0 :4 ], byteorder = _HEADER_BYTEORDER )
93+     root_offset  =  int .from_bytes (flatbuffer_data [0 :4 ], byteorder = HEADER_BYTEORDER )
10694
10795    return  (
10896        # New root offset. 
109-         (root_offset  +  len (header_data )).to_bytes (4 , byteorder = _HEADER_BYTEORDER )
97+         (root_offset  +  len (header_data )).to_bytes (4 , byteorder = HEADER_BYTEORDER )
11098        # Existing magic bytes. 
11199        +  flatbuffer_data [4 :8 ]
112100        # Provided header + padding. 
@@ -171,11 +159,9 @@ def from_bytes(data: bytes) -> "_ExtendedHeader":
171159
172160        return  _ExtendedHeader (
173161            magic = data [0 :4 ],
174-             length = int .from_bytes (data [4 :8 ], byteorder = _HEADER_BYTEORDER ),
175-             program_size = int .from_bytes (data [8 :16 ], byteorder = _HEADER_BYTEORDER ),
176-             segment_base_offset = int .from_bytes (
177-                 data [16 :24 ], byteorder = _HEADER_BYTEORDER 
178-             ),
162+             length = int .from_bytes (data [4 :8 ], byteorder = HEADER_BYTEORDER ),
163+             program_size = int .from_bytes (data [8 :16 ], byteorder = HEADER_BYTEORDER ),
164+             segment_base_offset = int .from_bytes (data [16 :24 ], byteorder = HEADER_BYTEORDER ),
179165        )
180166
181167    def  is_valid (self ) ->  bool :
@@ -201,35 +187,16 @@ def to_bytes(self) -> bytes:
201187            # fields to this header in the future. Always use the proper size 
202188            # (i.e., ignore self.length) since there's no reason to create an 
203189            # invalid header. 
204-             +  self .EXPECTED_LENGTH .to_bytes (4 , byteorder = _HEADER_BYTEORDER )
190+             +  self .EXPECTED_LENGTH .to_bytes (4 , byteorder = HEADER_BYTEORDER )
205191            # uint64_t: Size of the flatbuffer data, including this header. 
206-             +  self .program_size .to_bytes (8 , byteorder = _HEADER_BYTEORDER )
192+             +  self .program_size .to_bytes (8 , byteorder = HEADER_BYTEORDER )
207193            # uint64_t: Offset to the start of the first segment, or zero if 
208194            # there are no segments. 
209-             +  self .segment_base_offset .to_bytes (8 , byteorder = _HEADER_BYTEORDER )
195+             +  self .segment_base_offset .to_bytes (8 , byteorder = HEADER_BYTEORDER )
210196        )
211197        return  data 
212198
213199
214- def  _pad_to (data : bytes , length : int ) ->  bytes :
215-     """Returns the input followed by enough zero bytes to become the requested length. 
216- 
217-     Args: 
218-         data: The data to pad. 
219-         length: The length of the returned data. 
220-     Returns: 
221-         The padded data. 
222-     Raises: 
223-         ValueError: If the requested length is less than the input length. 
224-     """ 
225-     if  length  <  len (data ):
226-         raise  ValueError (f"Data length { len (data )} { length }  )
227-     if  length  >  len (data ):
228-         data  =  data  +  b"\x00 "  *  (length  -  len (data ))
229-     assert  len (data ) ==  length 
230-     return  data 
231- 
232- 
233200def  _get_extended_header (program_data : bytes ) ->  Optional [_ExtendedHeader ]:
234201    """Returns the extended header of the program data, if present and valid.""" 
235202    try :
@@ -330,7 +297,7 @@ def _extract_constant_segment(
330297        constant_segment_data .append (buffer .storage )
331298        buffer_length  =  len (buffer .storage )
332299        pad_length  =  (
333-             _padding_required (buffer_length , tensor_alignment )
300+             padding_required (buffer_length , tensor_alignment )
334301            if  tensor_alignment  is  not None 
335302            else  0 
336303        )
@@ -432,11 +399,11 @@ def serialize_pte_binary(
432399        )
433400        program .segments .append (
434401            DataSegment (
435-                 offset = _aligned_size (prev_end , segment_alignment ), size = len (data )
402+                 offset = aligned_size (prev_end , segment_alignment ), size = len (data )
436403            )
437404        )
438405        # Add to aggregate segments cord with padding. 
439-         padding_length  =  _padding_required (len (segments_data ), segment_alignment )
406+         padding_length  =  padding_required (len (segments_data ), segment_alignment )
440407        if  padding_length  >  0 :
441408            segments_data .append (b"\x00 "  *  padding_length )
442409        segments_data .append (data )
@@ -454,15 +421,15 @@ def serialize_pte_binary(
454421
455422    # Size of the header to insert. Its size is padded to the largest 
456423    # force_align value present in the schema. 
457-     padded_header_length : int  =  _aligned_size (
424+     padded_header_length : int  =  aligned_size (
458425        input_size = _ExtendedHeader .EXPECTED_LENGTH ,
459426        alignment = result .max_alignment ,
460427    )
461428    # Size of the program with the header inserted. 
462429    program_size : int  =  padded_header_length  +  len (result .data )
463430    # Offset to the first segment, or zero if there are no segments. 
464431    segment_base_offset : int  =  (
465-         _aligned_size (input_size = program_size , alignment = segment_alignment )
432+         aligned_size (input_size = program_size , alignment = segment_alignment )
466433        if  len (segments_data ) >  0 
467434        else  0 
468435    )
@@ -471,7 +438,7 @@ def serialize_pte_binary(
471438    header_data : bytes  =  _ExtendedHeader (
472439        program_size = program_size , segment_base_offset = segment_base_offset 
473440    ).to_bytes ()
474-     header_data  =  _pad_to (header_data , padded_header_length )
441+     header_data  =  pad_to (header_data , padded_header_length )
475442
476443    # Insert the header into the flatbuffer data. 
477444    program_data : bytes  =  _insert_flatbuffer_header (
@@ -496,7 +463,7 @@ def serialize_pte_binary(
496463    # - segments data (optional); aligned to segment_alignment. 
497464    pte_data  =  Cord (program_data )
498465    if  len (segments_data ) >  0 :
499-         padding_length  =  _padding_required (len (pte_data ), segment_alignment )
466+         padding_length  =  padding_required (len (pte_data ), segment_alignment )
500467        pte_data .append (b"\x00 "  *  padding_length )
501468        # The first segment after program data should start at the segment base offset. 
502469        assert  (
0 commit comments