1010import  os 
1111import  tempfile 
1212from  dataclasses  import  dataclass 
13- from  typing  import  ClassVar , Dict , List , Literal , Optional 
13+ from  typing  import  ClassVar , Dict , List , Literal , Optional ,  Sequence 
1414
1515import  pkg_resources 
1616from  executorch .exir ._serialize ._cord  import  Cord 
1717from  executorch .exir ._serialize ._dataclass  import  _DataclassEncoder , _json_to_dataclass 
1818
1919from  executorch .exir ._serialize ._flatbuffer  import  _flatc_compile , _flatc_decompile 
2020from  executorch .exir ._serialize ._program  import  _insert_flatbuffer_header 
21- from  executorch .exir ._serialize .data_serializer  import  DataPayload , DataSerializer 
21+ from  executorch .exir ._serialize .data_serializer  import  (
22+     DataPayload ,
23+     DataSerializer ,
24+     TensorEntry ,
25+ )
2226
2327from  executorch .exir ._serialize .padding  import  aligned_size , pad_to , padding_required 
2428
25- # Byte order of numbers written to flat tensor headers. Always little-endian 
26- # regardless of the host system, since all commonly-used modern CPUs are little 
27- # endian. 
28- _HEADER_BYTEORDER : Literal ["little" ] =  "little" 
29- 
3029from  executorch .extension .flat_tensor .serialize .flat_tensor_schema  import  (
3130    DataSegment ,
3231    FlatTensor ,
3332    TensorMetadata ,
3433)
3534
35+ # Byte order of numbers written to flat tensor headers. Always little-endian 
36+ # regardless of the host system, since all commonly-used modern CPUs are little 
37+ # endian. 
38+ _HEADER_BYTEORDER : Literal ["little" ] =  "little" 
39+ 
3640
3741def  _serialize_to_flatbuffer (flat_tensor : FlatTensor ) ->  Cord :
3842    """Serializes a FlatTensor to a flatbuffer and returns the serialized data.""" 
@@ -209,6 +213,62 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
209213    return  None 
210214
211215
def _extract_tensors(
    fqn_to_tensor: Dict[str, TensorEntry],
    buffers: Sequence[bytes],
    segments: List[Cord],
    tensor_alignment: int,
) -> List[TensorMetadata]:
    """Places tensors into a single new segment and describes where each lives.

    Every distinct buffer referenced by `fqn_to_tensor` is written exactly once
    into one new segment. Before each buffer after the first non-empty write,
    zero bytes are appended according to
    `padding_required(len(segment), tensor_alignment)`. Tensors that share a
    `buffer_index` share the same offset in the returned metadata.

    Args:
        fqn_to_tensor: A map from fully qualified names to tensor entries.
            Each entry must have a non-None `layout` and a `buffer_index`
            that is a valid, non-negative index into `buffers`.
        buffers: A sequence of tensor buffers.
        segments: A list of segments to append the tensor data to. Modified
            in-place: exactly one segment is appended, even when
            `fqn_to_tensor` is empty.
        tensor_alignment: The alignment of the tensor data within the segment.

    Returns:
        A list of TensorMetadata, one per entry of `fqn_to_tensor` in iteration
        order. Each item's `segment_index` is the index of the newly appended
        segment and its `offset` is the byte offset of the tensor's buffer
        within that segment.

    Raises:
        AssertionError: If an entry has no layout, or its buffer index is
            outside `[0, len(buffers))`.
    """
    tensor_data: Cord = Cord()
    tensors: List[TensorMetadata] = []
    # The new segment is appended after the loop, so its index is the current
    # length of `segments`.
    segment_index = len(segments)
    # Maps buffer index -> offset of that buffer within tensor_data.
    saved_offsets: Dict[int, int] = {}

    for fqn, tensor_entry in fqn_to_tensor.items():
        layout = tensor_entry.layout
        assert layout is not None

        # Check index into the tensor buffers is valid. Negative values are
        # rejected explicitly: Python would otherwise interpret them as
        # indexing from the end and silently pick the wrong buffer.
        buffer_index = tensor_entry.buffer_index
        assert 0 <= buffer_index < len(buffers), (
            f"Invalid index {buffer_index}: "
            f"expected 0 <= index < {len(buffers)}."
        )

        # Only append the buffer if it has not been written already.
        if buffer_index not in saved_offsets:
            if len(tensor_data) > 0:
                # Add padding to round off the previous tensor offset.
                pad_length = padding_required(len(tensor_data), tensor_alignment)
                # Exactly one zero byte per pad unit; any extra character in
                # this literal would break the alignment of what follows.
                tensor_data.append(b"\x00" * pad_length)
            # Record the offset, then append the buffer at that offset.
            saved_offsets[buffer_index] = len(tensor_data)
            tensor_data.append(buffers[buffer_index])

        tensors.append(
            TensorMetadata(
                fully_qualified_name=fqn,
                scalar_type=layout.scalar_type,
                sizes=layout.sizes,
                dim_order=layout.dim_order,
                segment_index=segment_index,
                offset=saved_offsets[buffer_index],
            )
        )

    segments.append(tensor_data)
    return tensors
270+ 
271+ 
212272class  FlatTensorSerializer (DataSerializer ):
213273    """A concrete implementation of the DataSerializer interface that 
214274    serializes and deserializes data to/from the FlatTensor format. 
@@ -227,61 +287,45 @@ def serialize(
227287        self ,
228288        data : DataPayload ,
229289    ) ->  Cord :
230-         """Serializes a list of tensor metadata and tensors into a blob.""" 
231- 
232-         flat_tensor_metadata : List [TensorMetadata ] =  []
233-         flat_tensor_data : Cord  =  Cord ()
234- 
235-         # {idx, offset} 
236-         saved_offsets : Dict [int , int ] =  {}
237- 
238-         for  fqn , tensor_entry  in  data .fqn_to_tensor .items ():
239-             assert  tensor_entry .layout  is  not None 
240-             # Check index into the tensor buffers is valid. 
241-             assert  tensor_entry .buffer_index  <  len (
242-                 data .buffers 
243-             ), f"Invalid index { tensor_entry .buffer_index } { len (data .buffers )}  
244- 
245-             # Check if the tensor has already been appended to the flat_tensor_data. 
246-             offset  =  saved_offsets .get (tensor_entry .buffer_index , - 1 )
247-             if  offset  ==  - 1 :
248-                 if  len (flat_tensor_data ) >  0 :
249-                     # Add padding to round off the previous tensor offset. 
250-                     pad_length  =  padding_required (
251-                         len (flat_tensor_data ), self .config .tensor_alignment 
252-                     )
253-                     flat_tensor_data .append (b"\x00 "  *  pad_length )
254-                 # Add to saved offsets. 
255-                 offset  =  len (flat_tensor_data )
256-                 saved_offsets [tensor_entry .buffer_index ] =  offset 
257-                 # Append to flat_tensor_data at the offset. 
258-                 flat_tensor_data .append (data .buffers [tensor_entry .buffer_index ])
259- 
260-             flat_tensor_metadata .append (
261-                 TensorMetadata (
262-                     fully_qualified_name = fqn ,
263-                     scalar_type = tensor_entry .layout .scalar_type ,
264-                     sizes = tensor_entry .layout .sizes ,
265-                     dim_order = tensor_entry .layout .dim_order ,
266-                     segment_index = 0 ,
267-                     offset = offset ,
290+         """Serializes a list of tensors and named data into a blob.""" 
291+ 
292+         segments : List [Cord ] =  []
293+         tensors  =  _extract_tensors (
294+             data .fqn_to_tensor ,
295+             data .buffers ,
296+             segments ,
297+             self .config .tensor_alignment ,
298+         )
299+ 
300+         data_segments : List [DataSegment ] =  []
301+         segment_data  =  Cord ()
302+         for  segment  in  segments :
303+             prev_end  =  (
304+                 (data_segments [- 1 ].offset  +  data_segments [- 1 ].size )
305+                 if  data_segments 
306+                 else  0 
307+             )
308+             data_segments .append (
309+                 DataSegment (
310+                     offset = aligned_size (prev_end , self .config .segment_alignment ),
311+                     size = len (segment ),
268312                )
269313            )
270- 
271-         # Pad flat_tensor_data to segment alignment. 
272-         segment_pad_length   =   padding_required ( 
273-             len ( flat_tensor_data ),  self . config . segment_alignment 
274-         ) 
275-         if   segment_pad_length   >   0 : 
276-             flat_tensor_data .append (b" \x00 "   *   segment_pad_length )
314+              # Pad segment_data to segment alignment. 
315+              segment_pad_length   =   padding_required ( 
316+                  len ( segment_data ),  self . config . segment_alignment 
317+             ) 
318+              if   segment_pad_length   >   0 : 
319+                  segment_data . append ( b" \x00 "   *   segment_pad_length ) 
320+             segment_data .append (segment )
277321
278322        # Create FlatTensor, which describes the contents of the file and
279323        # points to all the data segments. It will be serialized to flatbuffer. 
280324        flat_tensor  =  FlatTensor (
281325            version = 0 ,  # Keep in sync with c++ version number in serialize.h 
282326            tensor_alignment = self .config .tensor_alignment ,
283-             tensors = flat_tensor_metadata ,
284-             segments = [ DataSegment ( offset = 0 ,  size = len ( flat_tensor_data ))] ,
327+             tensors = tensors ,
328+             segments = data_segments ,
285329            named_data = [],
286330        )
287331
@@ -307,7 +351,7 @@ def serialize(
307351            flatbuffer_offset = padded_header_length ,
308352            flatbuffer_size = len (flatbuffer_payload ),
309353            segment_base_offset = segment_base_offset ,
310-             segment_data_size = len (flat_tensor_data ),
354+             segment_data_size = len (segment_data ),
311355        ).to_bytes ()
312356
313357        # Pad header and payload to segment alignment. 
@@ -327,15 +371,15 @@ def serialize(
327371        assert  eh .flatbuffer_size  ==  original_flatbuffer_payload_size 
328372        assert  eh .segment_base_offset  ==  segment_base_offset 
329373        assert  eh .flatbuffer_offset  ==  padded_header_length 
330-         assert  eh .segment_data_size  ==  len (flat_tensor_data )
374+         assert  eh .segment_data_size  ==  len (segment_data )
331375
332376        del  header_data 
333377        del  flatbuffer_payload 
334378
335379        # Place everything into one segment. 
336380        payload  =  Cord ()
337381        payload .append (injected_flatbuffer_data )
338-         payload .append (flat_tensor_data )
382+         payload .append (segment_data )
339383
340384        return  payload 
341385
0 commit comments