1616from  executorch .exir ._serialize .padding  import  aligned_size 
1717
1818from  executorch .exir .schema  import  ScalarType 
19+ from  executorch .extension .flat_tensor .serialize .flat_tensor_schema  import  TensorMetadata 
1920
2021from  executorch .extension .flat_tensor .serialize .serialize  import  (
22+     _deserialize_to_flat_tensor ,
2123    FlatTensorConfig ,
2224    FlatTensorHeader ,
2325    FlatTensorSerializer ,
2426)
2527
2628# Test artifacts. 
27- TEST_TENSOR_BUFFER  =  [b"tensor"  ]
29+ TEST_TENSOR_BUFFER  =  [b"\x11 "    *   4 ,  b" \x22 "   *   32 ]
2830TEST_TENSOR_MAP  =  {
2931    "fqn1" : TensorEntry (
3032        buffer_index = 0 ,
4244            dim_order = [0 , 1 , 2 ],
4345        ),
4446    ),
47+     "fqn3" : TensorEntry (
48+         buffer_index = 1 ,
49+         layout = TensorLayout (
50+             scalar_type = ScalarType .INT ,
51+             sizes = [2 , 2 , 2 ],
52+             dim_order = [0 , 1 ],
53+         ),
54+     ),
4555}
4656TEST_DATA_PAYLOAD  =  DataPayload (
4757    buffers = TEST_TENSOR_BUFFER ,
5060
5161
5262class  TestSerialize (unittest .TestCase ):
63+     # TODO(T211851359): improve test coverage. 
64+     def  check_tensor_metadata (
65+         self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata 
66+     ) ->  None :
67+         self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
68+         self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
69+         self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
70+ 
5371    def  test_serialize (self ) ->  None :
5472        config  =  FlatTensorConfig ()
5573        serializer : DataSerializer  =  FlatTensorSerializer (config )
5674
57-         data  =  bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
75+         serialized_data  =  bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
5876
59-         header  =  FlatTensorHeader .from_bytes (data [0  : FlatTensorHeader .EXPECTED_LENGTH ])
77+         # Check header. 
78+         header  =  FlatTensorHeader .from_bytes (serialized_data [0  : FlatTensorHeader .EXPECTED_LENGTH ])
6079        self .assertTrue (header .is_valid ())
6180
6281        # Header is aligned to config.segment_alignment, which is where the flatbuffer starts. 
@@ -75,9 +94,81 @@ def test_serialize(self) -> None:
7594        self .assertTrue (header .segment_base_offset , expected_segment_base_offset )
7695
7796        # TEST_TENSOR_BUFFER is aligned to config.segment_alignment. 
78-         self .assertEqual (header .segment_data_size , config .segment_alignment )
97+         expected_segment_data_size  =  aligned_size (
98+             sum (len (buffer ) for  buffer  in  TEST_TENSOR_BUFFER ), config .segment_alignment 
99+         )
100+         self .assertEqual (header .segment_data_size , expected_segment_data_size )
79101
80102        # Confirm the flatbuffer magic is present. 
81103        self .assertEqual (
82-             data [header .flatbuffer_offset  +  4  : header .flatbuffer_offset  +  8 ], b"FT01" 
104+             serialized_data [header .flatbuffer_offset  +  4  : header .flatbuffer_offset  +  8 ], b"FT01" 
83105        )
106+ 
107+         # Check flat tensor data. 
108+         flat_tensor_bytes  =  serialized_data [
109+             header .flatbuffer_offset  : header .flatbuffer_offset  +  header .flatbuffer_size 
110+         ]
111+ 
112+         flat_tensor  =  _deserialize_to_flat_tensor (flat_tensor_bytes )
113+ 
114+         self .assertEqual (flat_tensor .version , 0 )
115+         self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
116+ 
117+         tensors  =  flat_tensor .tensors 
118+         self .assertEqual (len (tensors ), 3 )
119+         self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
120+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn1" ].layout , tensors [0 ])
121+         self .assertEqual (tensors [0 ].segment_index , 0 )
122+         self .assertEqual (tensors [0 ].offset , 0 )
123+ 
124+         self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
125+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn2" ].layout , tensors [1 ])
126+         self .assertEqual (tensors [1 ].segment_index , 0 )
127+         self .assertEqual (tensors [1 ].offset , 0 )
128+ 
129+         self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
130+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn3" ].layout , tensors [2 ])
131+         self .assertEqual (tensors [2 ].segment_index , 0 )
132+         self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
133+ 
134+         segments  =  flat_tensor .segments 
135+         self .assertEqual (len (segments ), 1 )
136+         self .assertEqual (segments [0 ].offset , 0 )
137+         self .assertEqual (segments [0 ].size , config .tensor_alignment  *  3 )
138+ 
139+         # Length of serialized_data matches segment_base_offset + segment_data_size. 
140+         self .assertEqual (
141+             header .segment_base_offset  +  header .segment_data_size , len (serialized_data )
142+         )
143+         self .assertTrue (segments [0 ].size  <=  header .segment_data_size )
144+ 
145+         # Check the contents of the segment. Expecting two tensors from 
146+         # TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32] 
147+         segment_data  =  serialized_data [
148+             header .segment_base_offset  : header .segment_base_offset  +  segments [0 ].size 
149+         ]
150+ 
151+         # Tensor: b"\x11" * 4 
152+         t0_start  =  0 
153+         t0_len  =  len (TEST_TENSOR_BUFFER [0 ])
154+         t0_end  =  t0_start  +  aligned_size (t0_len , config .tensor_alignment )
155+         self .assertEqual (
156+             segment_data [t0_start  : t0_start  +  t0_len ], TEST_TENSOR_BUFFER [0 ]
157+         )
158+         padding  =  b"\x00 "  *  (t0_end  -  t0_len )
159+         self .assertEqual (segment_data [t0_start  +  t0_len  : t0_end ], padding )
160+ 
161+         # Tensor: b"\x22" * 32 
162+         t1_start  =  t0_end 
163+         t1_len  =  len (TEST_TENSOR_BUFFER [1 ])
164+         t1_end  =  t1_start  +  aligned_size (t1_len , config .tensor_alignment )
165+         self .assertEqual (
166+             segment_data [t1_start  : t1_start  +  t1_len ],
167+             TEST_TENSOR_BUFFER [1 ],
168+         )
169+         padding  =  b"\x00 "  *  (t1_end  -  (t1_len  +  t1_start ))
170+         self .assertEqual (segment_data [t1_start  +  t1_len  : t1_start  +  t1_end ], padding )
171+ 
172+         # Check length of the segment is expected. 
173+         self .assertEqual (segments [0 ].size , aligned_size (t1_end , config .segment_alignment ))
174+         self .assertEqual (segments [0 ].size , header .segment_data_size )
0 commit comments