1010import os
1111import tempfile
1212from dataclasses import dataclass
13- from typing import ClassVar , Dict , List , Literal , Optional
13+ from typing import ClassVar , Dict , List , Literal , Optional , Sequence
1414
1515import pkg_resources
1616from executorch .exir ._serialize ._cord import Cord
1717from executorch .exir ._serialize ._dataclass import _DataclassEncoder , _json_to_dataclass
1818
1919from executorch .exir ._serialize ._flatbuffer import _flatc_compile , _flatc_decompile
2020from executorch .exir ._serialize ._program import _insert_flatbuffer_header
21- from executorch .exir ._serialize .data_serializer import DataPayload , DataSerializer
21+ from executorch .exir ._serialize .data_serializer import (
22+ DataPayload ,
23+ DataSerializer ,
24+ TensorEntry ,
25+ )
2226
2327from executorch .exir ._serialize .padding import aligned_size , pad_to , padding_required
2428
25- # Byte order of numbers written to flat tensor headers. Always little-endian
26- # regardless of the host system, since all commonly-used modern CPUs are little
27- # endian.
28- _HEADER_BYTEORDER : Literal ["little" ] = "little"
29-
3029from executorch .extension .flat_tensor .serialize .flat_tensor_schema import (
3130 DataSegment ,
3231 FlatTensor ,
3332 TensorMetadata ,
3433)
3534
35+ # Byte order of numbers written to flat tensor headers. Always little-endian
36+ # regardless of the host system, since all commonly-used modern CPUs are little
37+ # endian.
38+ _HEADER_BYTEORDER : Literal ["little" ] = "little"
39+
3640
3741def _serialize_to_flatbuffer (flat_tensor : FlatTensor ) -> Cord :
3842 """Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -209,6 +213,62 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
209213 return None
210214
211215
def _extract_tensors(
    fqn_to_tensor: Dict[str, TensorEntry],
    buffers: Sequence[bytes],
    segments: List[Cord],
    tensor_alignment: int,
) -> List[TensorMetadata]:
    """Places tensors into a single segment, aligned to tensor_alignment within
    the segment.

    Each distinct buffer index is written to the segment at most once; tensor
    entries that refer to the same buffer index get the same offset.

    Args:
        fqn_to_tensor: A map from fully qualified names to tensor entries.
        buffers: A sequence of tensor buffers, indexed by each entry's
            buffer_index.
        segments: A list of segments to append the tensor data to. Modified
            in-place: exactly one new segment is appended, even if
            fqn_to_tensor is empty.
        tensor_alignment: The alignment of the tensor data within the segment.

    Returns:
        A list of TensorMetadata, one per entry in fqn_to_tensor (in iteration
        order), which describes the tensors in the segment.

    Raises:
        AssertionError: If a tensor entry has no layout, or its buffer_index is
            not smaller than len(buffers).
    """
    tensor_data: Cord = Cord()
    tensors: List[TensorMetadata] = []
    # Maps buffer index -> offset of that buffer within tensor_data.
    saved_offsets: Dict[int, int] = {}
    # Index the new segment will have once it is appended below. `segments` is
    # not touched inside the loop, so this can be computed once.
    segment_index = len(segments)

    for fqn, tensor_entry in fqn_to_tensor.items():
        assert tensor_entry.layout is not None
        # Check index into the tensor buffers is valid.
        assert tensor_entry.buffer_index < len(
            buffers
        ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(buffers)} ."

        # Check if the buffer has already been appended to tensor_data.
        offset = saved_offsets.get(tensor_entry.buffer_index)
        if offset is None:
            if len(tensor_data) > 0:
                # Add padding to round off the previous tensor offset. The pad
                # unit must be exactly one byte so that pad_length bytes are
                # written; a multi-byte unit would misalign the next tensor.
                pad_length = padding_required(len(tensor_data), tensor_alignment)
                tensor_data.append(b"\x00" * pad_length)
            # Record where this buffer starts so later entries can reuse it.
            offset = len(tensor_data)
            saved_offsets[tensor_entry.buffer_index] = offset
            # Append the buffer to tensor_data at the offset.
            tensor_data.append(buffers[tensor_entry.buffer_index])

        tensors.append(
            TensorMetadata(
                fully_qualified_name=fqn,
                scalar_type=tensor_entry.layout.scalar_type,
                sizes=tensor_entry.layout.sizes,
                dim_order=tensor_entry.layout.dim_order,
                segment_index=segment_index,
                offset=offset,
            )
        )

    segments.append(tensor_data)
    return tensors
270+
271+
212272class FlatTensorSerializer (DataSerializer ):
213273 """A concrete implementation of the DataSerializer interface that
214274 serializes and deserializes data to/from the FlatTensor format.
@@ -227,61 +287,45 @@ def serialize(
227287 self ,
228288 data : DataPayload ,
229289 ) -> Cord :
230- """Serializes a list of tensor metadata and tensors into a blob."""
231-
232- flat_tensor_metadata : List [TensorMetadata ] = []
233- flat_tensor_data : Cord = Cord ()
234-
235- # {idx, offset}
236- saved_offsets : Dict [int , int ] = {}
237-
238- for fqn , tensor_entry in data .fqn_to_tensor .items ():
239- assert tensor_entry .layout is not None
240- # Check index into the tensor buffers is valid.
241- assert tensor_entry .buffer_index < len (
242- data .buffers
243- ), f"Invalid index { tensor_entry .buffer_index } is greater than tensor buffer size { len (data .buffers )} ."
244-
245- # Check if the tensor has already been appended to the flat_tensor_data.
246- offset = saved_offsets .get (tensor_entry .buffer_index , - 1 )
247- if offset == - 1 :
248- if len (flat_tensor_data ) > 0 :
249- # Add padding to round off the previous tensor offset.
250- pad_length = padding_required (
251- len (flat_tensor_data ), self .config .tensor_alignment
252- )
253- flat_tensor_data .append (b"\x00 " * pad_length )
254- # Add to saved offsets.
255- offset = len (flat_tensor_data )
256- saved_offsets [tensor_entry .buffer_index ] = offset
257- # Append to flat_tensor_data at the offset.
258- flat_tensor_data .append (data .buffers [tensor_entry .buffer_index ])
259-
260- flat_tensor_metadata .append (
261- TensorMetadata (
262- fully_qualified_name = fqn ,
263- scalar_type = tensor_entry .layout .scalar_type ,
264- sizes = tensor_entry .layout .sizes ,
265- dim_order = tensor_entry .layout .dim_order ,
266- segment_index = 0 ,
267- offset = offset ,
290+ """Serializes a list of tensors and named data into a blob."""
291+
292+ segments : List [Cord ] = []
293+ tensors = _extract_tensors (
294+ data .fqn_to_tensor ,
295+ data .buffers ,
296+ segments ,
297+ self .config .tensor_alignment ,
298+ )
299+
300+ data_segments : List [DataSegment ] = []
301+ segment_data = Cord ()
302+ for segment in segments :
303+ prev_end = (
304+ (data_segments [- 1 ].offset + data_segments [- 1 ].size )
305+ if data_segments
306+ else 0
307+ )
308+ data_segments .append (
309+ DataSegment (
310+ offset = aligned_size (prev_end , self .config .segment_alignment ),
311+ size = len (segment ),
268312 )
269313 )
270-
271- # Pad flat_tensor_data to segment alignment.
272- segment_pad_length = padding_required (
273- len ( flat_tensor_data ), self . config . segment_alignment
274- )
275- if segment_pad_length > 0 :
276- flat_tensor_data .append (b" \x00 " * segment_pad_length )
314+ # Pad segment_data to segment alignment.
315+ segment_pad_length = padding_required (
316+ len ( segment_data ), self . config . segment_alignment
317+ )
318+ if segment_pad_length > 0 :
319+ segment_data . append ( b" \x00 " * segment_pad_length )
320+ segment_data .append (segment )
277321
278322 # Create FlatTensor, which describes the contents of the file and
279323 # points to all the data segments. It will be serialized to flatbuffer.
280324 flat_tensor = FlatTensor (
281325 version = 0 , # Keep in sync with c++ version number in serialize.h
282326 tensor_alignment = self .config .tensor_alignment ,
283- tensors = flat_tensor_metadata ,
284- segments = [ DataSegment ( offset = 0 , size = len ( flat_tensor_data ))] ,
327+ tensors = tensors ,
328+ segments = data_segments ,
285329 named_data = [],
286330 )
287331
@@ -307,7 +351,7 @@ def serialize(
307351 flatbuffer_offset = padded_header_length ,
308352 flatbuffer_size = len (flatbuffer_payload ),
309353 segment_base_offset = segment_base_offset ,
310- segment_data_size = len (flat_tensor_data ),
354+ segment_data_size = len (segment_data ),
311355 ).to_bytes ()
312356
313357 # Pad header and payload to segment alignment.
@@ -327,15 +371,15 @@ def serialize(
327371 assert eh .flatbuffer_size == original_flatbuffer_payload_size
328372 assert eh .segment_base_offset == segment_base_offset
329373 assert eh .flatbuffer_offset == padded_header_length
330- assert eh .segment_data_size == len (flat_tensor_data )
374+ assert eh .segment_data_size == len (segment_data )
331375
332376 del header_data
333377 del flatbuffer_payload
334378
335379 # Place everything into one segment.
336380 payload = Cord ()
337381 payload .append (injected_flatbuffer_data )
338- payload .append (flat_tensor_data )
382+ payload .append (segment_data )
339383
340384 return payload
341385
0 commit comments