1616from  executorch .exir ._serialize .padding  import  aligned_size 
1717
1818from  executorch .exir .schema  import  ScalarType 
19+ from  executorch .extension .flat_tensor .serialize .flat_tensor_schema  import  TensorMetadata 
1920
2021from  executorch .extension .flat_tensor .serialize .serialize  import  (
22+     _convert_to_flat_tensor ,
2123    FlatTensorConfig ,
2224    FlatTensorHeader ,
2325    FlatTensorSerializer ,
2426)
2527
2628# Test artifacts. 
27- TEST_TENSOR_BUFFER  =  [b"tensor"  ]
29+ TEST_TENSOR_BUFFER  =  [b"\x11 "    *   4 ,  b" \x22 "   *   32 ]
2830TEST_TENSOR_MAP  =  {
2931    "fqn1" : TensorEntry (
3032        buffer_index = 0 ,
4244            dim_order = [0 , 1 , 2 ],
4345        ),
4446    ),
47+     "fqn3" : TensorEntry (
48+         buffer_index = 1 ,
49+         layout = TensorLayout (
50+             scalar_type = ScalarType .INT ,
51+             sizes = [2 , 2 , 2 ],
52+             dim_order = [0 , 1 ],
53+         ),
54+     ),
4555}
4656TEST_DATA_PAYLOAD  =  DataPayload (
4757    buffers = TEST_TENSOR_BUFFER ,
5060
5161
5262class  TestSerialize (unittest .TestCase ):
63+     # TODO(T211851359): improve test coverage. 
64+     def  check_tensor_metadata (
65+         self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata 
66+     ) ->  None :
67+         self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
68+         self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
69+         self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
70+ 
5371    def  test_serialize (self ) ->  None :
5472        config  =  FlatTensorConfig ()
5573        serializer : DataSerializer  =  FlatTensorSerializer (config )
5674
5775        data  =  bytes (serializer .serialize (TEST_DATA_PAYLOAD ))
5876
77+         # Check header. 
5978        header  =  FlatTensorHeader .from_bytes (data [0  : FlatTensorHeader .EXPECTED_LENGTH ])
6079        self .assertTrue (header .is_valid ())
6180
@@ -75,9 +94,73 @@ def test_serialize(self) -> None:
7594        self .assertTrue (header .segment_base_offset , expected_segment_base_offset )
7695
7796        # TEST_TENSOR_BUFFER is aligned to config.segment_alignment. 
78-         self .assertEqual (header .segment_data_size , config .segment_alignment )
97+         expected_segment_data_size  =  aligned_size (
98+             sum (len (buffer ) for  buffer  in  TEST_TENSOR_BUFFER ), config .segment_alignment 
99+         )
100+         self .assertEqual (header .segment_data_size , expected_segment_data_size )
79101
80102        # Confirm the flatbuffer magic is present. 
81103        self .assertEqual (
82104            data [header .flatbuffer_offset  +  4  : header .flatbuffer_offset  +  8 ], b"FT01" 
83105        )
106+ 
107+         # Check flat tensor data. 
108+         flat_tensor_bytes  =  data [
109+             header .flatbuffer_offset  : header .flatbuffer_offset  +  header .flatbuffer_size 
110+         ]
111+ 
112+         flat_tensor  =  _convert_to_flat_tensor (flat_tensor_bytes )
113+ 
114+         self .assertEqual (flat_tensor .version , 0 )
115+         self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
116+ 
117+         tensors  =  flat_tensor .tensors 
118+         self .assertEqual (len (tensors ), 3 )
119+         self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
120+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn1" ].layout , tensors [0 ])
121+         self .assertEqual (tensors [0 ].segment_index , 0 )
122+         self .assertEqual (tensors [0 ].offset , 0 )
123+ 
124+         self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
125+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn2" ].layout , tensors [1 ])
126+         self .assertEqual (tensors [1 ].segment_index , 0 )
127+         self .assertEqual (tensors [1 ].offset , 0 )
128+ 
129+         self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
130+         self .check_tensor_metadata (TEST_TENSOR_MAP ["fqn3" ].layout , tensors [2 ])
131+         self .assertEqual (tensors [2 ].segment_index , 0 )
132+         self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
133+ 
134+         segments  =  flat_tensor .segments 
135+         self .assertEqual (len (segments ), 1 )
136+         self .assertEqual (segments [0 ].offset , 0 )
137+         self .assertEqual (segments [0 ].size , config .tensor_alignment  *  3 )
138+ 
139+         # Check segment data. 
140+         self .assertEqual (
141+             header .segment_base_offset  +  header .segment_data_size , len (data )
142+         )
143+         self .assertTrue (segments [0 ].size  <=  header .segment_data_size )
144+ 
145+         segment_data  =  data [
146+             header .segment_base_offset  : header .segment_base_offset  +  segments [0 ].size 
147+         ]
148+ 
149+         t0_start  =  0 
150+         t0_len  =  len (TEST_TENSOR_BUFFER [0 ])
151+         t0_end  =  config .tensor_alignment 
152+         self .assertEqual (
153+             segment_data [t0_start  : t0_start  +  t0_len ], TEST_TENSOR_BUFFER [0 ]
154+         )
155+         padding  =  b"\x00 "  *  (t0_end  -  t0_len )
156+         self .assertEqual (segment_data [t0_start  +  t0_len  : t0_end ], padding )
157+ 
158+         t1_start  =  config .tensor_alignment 
159+         t1_len  =  len (TEST_TENSOR_BUFFER [1 ])
160+         t1_end  =  config .tensor_alignment  *  3 
161+         self .assertEqual (
162+             segment_data [t1_start  : t1_start  +  t1_len ],
163+             TEST_TENSOR_BUFFER [1 ],
164+         )
165+         padding  =  b"\x00 "  *  (t1_end  -  (t1_len  +  t1_start ))
166+         self .assertEqual (segment_data [t1_start  +  t1_len  : t1_end ], padding )
0 commit comments