88
99import unittest
1010
11+ from typing import List
12+
1113from executorch .exir ._serialize .data_serializer import (
1214 DataPayload ,
1315 DataSerializer ,
1820from executorch .exir ._serialize .padding import aligned_size
1921
2022from executorch .exir .schema import ScalarType
23+ from executorch .extension .flat_tensor .serialize .flat_tensor_schema import TensorMetadata
2124
2225from executorch .extension .flat_tensor .serialize .serialize import (
26+ _deserialize_to_flat_tensor ,
2327 FlatTensorConfig ,
2428 FlatTensorHeader ,
2529 FlatTensorSerializer ,
2630)
2731
2832# Test artifacts.
29- TEST_TENSOR_BUFFER = [b"tensor" ]
33+ TEST_TENSOR_BUFFER : List [ bytes ] = [b"\x11 " * 4 , b" \x22 " * 32 ]
3034TEST_TENSOR_MAP = {
3135 "fqn1" : TensorEntry (
3236 buffer_index = 0 ,
4448 dim_order = [0 , 1 , 2 ],
4549 ),
4650 ),
51+ "fqn3" : TensorEntry (
52+ buffer_index = 1 ,
53+ layout = TensorLayout (
54+ scalar_type = ScalarType .INT ,
55+ sizes = [2 , 2 , 2 ],
56+ dim_order = [0 , 1 ],
57+ ),
58+ ),
4759}
4860TEST_DATA_PAYLOAD = DataPayload (
4961 buffers = TEST_TENSOR_BUFFER ,
5264
5365
5466class TestSerialize (unittest .TestCase ):
67+ # TODO(T211851359): improve test coverage.
68+ def check_tensor_metadata (
69+ self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata
70+ ) -> None :
71+ self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
72+ self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
73+ self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
74+
5575 def test_serialize (self ) -> None :
5676 config = FlatTensorConfig ()
5777 serializer : DataSerializer = FlatTensorSerializer (config )
5878
59- data = bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
79+ serialized_data = bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
6080
61- header = FlatTensorHeader .from_bytes (data [0 : FlatTensorHeader .EXPECTED_LENGTH ])
81+ # Check header.
82+ header = FlatTensorHeader .from_bytes (
83+ serialized_data [0 : FlatTensorHeader .EXPECTED_LENGTH ]
84+ )
6285 self .assertTrue (header .is_valid ())
6386
6487 # Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
@@ -77,9 +100,86 @@ def test_serialize(self) -> None:
77100 self .assertTrue (header .segment_base_offset , expected_segment_base_offset )
78101
79102 # TEST_TENSOR_BUFFER is aligned to config.segment_alignment.
80- self .assertEqual (header .segment_data_size , config .segment_alignment )
103+ expected_segment_data_size = aligned_size (
104+ sum (len (buffer ) for buffer in TEST_TENSOR_BUFFER ), config .segment_alignment
105+ )
106+ self .assertEqual (header .segment_data_size , expected_segment_data_size )
81107
82108 # Confirm the flatbuffer magic is present.
83109 self .assertEqual (
84- data [header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8 ], b"FT01"
110+ serialized_data [
111+ header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8
112+ ],
113+ b"FT01" ,
114+ )
115+
116+ # Check flat tensor data.
117+ flat_tensor_bytes = serialized_data [
118+ header .flatbuffer_offset : header .flatbuffer_offset + header .flatbuffer_size
119+ ]
120+
121+ flat_tensor = _deserialize_to_flat_tensor (flat_tensor_bytes )
122+
123+ self .assertEqual (flat_tensor .version , 0 )
124+ self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
125+
126+ tensors = flat_tensor .tensors
127+ self .assertEqual (len (tensors ), 3 )
128+ self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
129+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn1" ].layout , tensors [0 ])
130+ self .assertEqual (tensors [0 ].segment_index , 0 )
131+ self .assertEqual (tensors [0 ].offset , 0 )
132+
133+ self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
134+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn2" ].layout , tensors [1 ])
135+ self .assertEqual (tensors [1 ].segment_index , 0 )
136+ self .assertEqual (tensors [1 ].offset , 0 )
137+
138+ self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
139+ self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn3" ].layout , tensors [2 ])
140+ self .assertEqual (tensors [2 ].segment_index , 0 )
141+ self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
142+
143+ segments = flat_tensor .segments
144+ self .assertEqual (len (segments ), 1 )
145+ self .assertEqual (segments [0 ].offset , 0 )
146+ self .assertEqual (segments [0 ].size , config .tensor_alignment * 3 )
147+
148+ # Length of serialized_data matches segment_base_offset + segment_data_size.
149+ self .assertEqual (
150+ header .segment_base_offset + header .segment_data_size , len (serialized_data )
151+ )
152+ self .assertTrue (segments [0 ].size <= header .segment_data_size )
153+
154+ # Check the contents of the segment. Expecting two tensors from
155+ # TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
156+ segment_data = serialized_data [
157+ header .segment_base_offset : header .segment_base_offset + segments [0 ].size
158+ ]
159+
160+ # Tensor: b"\x11" * 4
161+ t0_start = 0
162+ t0_len = len (TEST_TENSOR_BUFFER [0 ])
163+ t0_end = t0_start + aligned_size (t0_len , config .tensor_alignment )
164+ self .assertEqual (
165+ segment_data [t0_start : t0_start + t0_len ], TEST_TENSOR_BUFFER [0 ]
166+ )
167+ padding = b"\x00 " * (t0_end - t0_len )
168+ self .assertEqual (segment_data [t0_start + t0_len : t0_end ], padding )
169+
170+ # Tensor: b"\x22" * 32
171+ t1_start = t0_end
172+ t1_len = len (TEST_TENSOR_BUFFER [1 ])
173+ t1_end = t1_start + aligned_size (t1_len , config .tensor_alignment )
174+ self .assertEqual (
175+ segment_data [t1_start : t1_start + t1_len ],
176+ TEST_TENSOR_BUFFER [1 ],
177+ )
178+ padding = b"\x00 " * (t1_end - (t1_len + t1_start ))
179+ self .assertEqual (segment_data [t1_start + t1_len : t1_start + t1_end ], padding )
180+
181+ # Check length of the segment is expected.
182+ self .assertEqual (
183+ segments [0 ].size , aligned_size (t1_end , config .segment_alignment )
85184 )
185+ self .assertEqual (segments [0 ].size , header .segment_data_size )
0 commit comments