88
99import  unittest 
1010
11+ from  typing  import  List 
12+ 
1113from  executorch .exir ._serialize .data_serializer  import  (
1214    DataPayload ,
1315    DataSerializer ,
1820from  executorch .exir ._serialize .padding  import  aligned_size 
1921
2022from  executorch .exir .schema  import  ScalarType 
23+ from  executorch .extension .flat_tensor .serialize .flat_tensor_schema  import  TensorMetadata 
2124
2225from  executorch .extension .flat_tensor .serialize .serialize  import  (
26+     _deserialize_to_flat_tensor ,
2327    FlatTensorConfig ,
2428    FlatTensorHeader ,
2529    FlatTensorSerializer ,
2630)
2731
2832# Test artifacts. 
29- TEST_TENSOR_BUFFER   =  [b"tensor"  ]
33+ TEST_TENSOR_BUFFER :  List [ bytes ]  =  [b"\x11 "    *   4 ,  b" \x22 "   *   32 ]
3034TEST_TENSOR_MAP  =  {
3135    "fqn1" : TensorEntry (
3236        buffer_index = 0 ,
4448            dim_order = [0 , 1 , 2 ],
4549        ),
4650    ),
51+     "fqn3" : TensorEntry (
52+         buffer_index = 1 ,
53+         layout = TensorLayout (
54+             scalar_type = ScalarType .INT ,
55+             sizes = [2 , 2 , 2 ],
56+             dim_order = [0 , 1 ],
57+         ),
58+     ),
4759}
4860TEST_DATA_PAYLOAD  =  DataPayload (
4961    buffers = TEST_TENSOR_BUFFER ,
5264
5365
5466class  TestSerialize (unittest .TestCase ):
67+     # TODO(T211851359): improve test coverage. 
68+     def  check_tensor_metadata (
69+         self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata 
70+     ) ->  None :
71+         self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
72+         self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
73+         self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
74+ 
5575    def  test_serialize (self ) ->  None :
5676        config  =  FlatTensorConfig ()
5777        serializer : DataSerializer  =  FlatTensorSerializer (config )
5878
59-         data  =  bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
79+         serialized_data  =  bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
6080
61-         header  =  FlatTensorHeader .from_bytes (data [0  : FlatTensorHeader .EXPECTED_LENGTH ])
81+         # Check header. 
82+         header  =  FlatTensorHeader .from_bytes (
83+             serialized_data [0  : FlatTensorHeader .EXPECTED_LENGTH ]
84+         )
6285        self .assertTrue (header .is_valid ())
6386
6487        # Header is aligned to config.segment_alignment, which is where the flatbuffer starts. 
@@ -77,9 +100,86 @@ def test_serialize(self) -> None:
77100        self .assertTrue (header .segment_base_offset , expected_segment_base_offset )
78101
79102        # TEST_TENSOR_BUFFER is aligned to config.segment_alignment. 
80-         self .assertEqual (header .segment_data_size , config .segment_alignment )
103+         expected_segment_data_size  =  aligned_size (
104+             sum (len (buffer ) for  buffer  in  TEST_TENSOR_BUFFER ), config .segment_alignment 
105+         )
106+         self .assertEqual (header .segment_data_size , expected_segment_data_size )
81107
82108        # Confirm the flatbuffer magic is present. 
83109        self .assertEqual (
84-             data [header .flatbuffer_offset  +  4  : header .flatbuffer_offset  +  8 ], b"FT01" 
110+             serialized_data [
111+                 header .flatbuffer_offset  +  4  : header .flatbuffer_offset  +  8 
112+             ],
113+             b"FT01" ,
114+         )
115+ 
116+         # Check flat tensor data. 
117+         flat_tensor_bytes  =  serialized_data [
118+             header .flatbuffer_offset  : header .flatbuffer_offset  +  header .flatbuffer_size 
119+         ]
120+ 
121+         flat_tensor  =  _deserialize_to_flat_tensor (flat_tensor_bytes )
122+ 
123+         self .assertEqual (flat_tensor .version , 0 )
124+         self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
125+ 
126+         tensors  =  flat_tensor .tensors 
127+         self .assertEqual (len (tensors ), 3 )
128+         self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
129+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn1" ].layout , tensors [0 ])
130+         self .assertEqual (tensors [0 ].segment_index , 0 )
131+         self .assertEqual (tensors [0 ].offset , 0 )
132+ 
133+         self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
134+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn2" ].layout , tensors [1 ])
135+         self .assertEqual (tensors [1 ].segment_index , 0 )
136+         self .assertEqual (tensors [1 ].offset , 0 )
137+ 
138+         self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
139+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn3" ].layout , tensors [2 ])
140+         self .assertEqual (tensors [2 ].segment_index , 0 )
141+         self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
142+ 
143+         segments  =  flat_tensor .segments 
144+         self .assertEqual (len (segments ), 1 )
145+         self .assertEqual (segments [0 ].offset , 0 )
146+         self .assertEqual (segments [0 ].size , config .tensor_alignment  *  3 )
147+ 
148+         # Length of serialized_data matches segment_base_offset + segment_data_size. 
149+         self .assertEqual (
150+             header .segment_base_offset  +  header .segment_data_size , len (serialized_data )
151+         )
152+         self .assertTrue (segments [0 ].size  <=  header .segment_data_size )
153+ 
154+         # Check the contents of the segment. Expecting two tensors from 
155+         # TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32] 
156+         segment_data  =  serialized_data [
157+             header .segment_base_offset  : header .segment_base_offset  +  segments [0 ].size 
158+         ]
159+ 
160+         # Tensor: b"\x11" * 4 
161+         t0_start  =  0 
162+         t0_len  =  len (TEST_TENSOR_BUFFER [0 ])
163+         t0_end  =  t0_start  +  aligned_size (t0_len , config .tensor_alignment )
164+         self .assertEqual (
165+             segment_data [t0_start  : t0_start  +  t0_len ], TEST_TENSOR_BUFFER [0 ]
166+         )
167+         padding  =  b"\x00 "  *  (t0_end  -  t0_len )
168+         self .assertEqual (segment_data [t0_start  +  t0_len  : t0_end ], padding )
169+ 
170+         # Tensor: b"\x22" * 32 
171+         t1_start  =  t0_end 
172+         t1_len  =  len (TEST_TENSOR_BUFFER [1 ])
173+         t1_end  =  t1_start  +  aligned_size (t1_len , config .tensor_alignment )
174+         self .assertEqual (
175+             segment_data [t1_start  : t1_start  +  t1_len ],
176+             TEST_TENSOR_BUFFER [1 ],
177+         )
178+         padding  =  b"\x00 "  *  (t1_end  -  (t1_len  +  t1_start ))
179+         self .assertEqual (segment_data [t1_start  +  t1_len  : t1_start  +  t1_end ], padding )
180+ 
181+         # Check length of the segment is expected. 
182+         self .assertEqual (
183+             segments [0 ].size , aligned_size (t1_end , config .segment_alignment )
85184        )
185+         self .assertEqual (segments [0 ].size , header .segment_data_size )
0 commit comments