1515)
1616
1717from  executorch .exir .schema  import  ScalarType 
18+ from  executorch .extension .flat_tensor .serialize .flat_tensor_schema  import  TensorMetadata 
1819
1920from  executorch .extension .flat_tensor .serialize .serialize  import  (
21+     _convert_to_flat_tensor ,
22+     FlatTensorConfig ,
2023    FlatTensorHeader ,
2124    FlatTensorSerializer ,
2225)
2326
2427# Test artifacts 
25- TEST_TENSOR_BUFFER  =  [b"tensor"  ]
28+ TEST_TENSOR_BUFFER  =  [b"\x11 "  * 4 ,  b" \x22 " * 32 ]
2629TEST_TENSOR_MAP  =  {
2730    "fqn1" : 0 ,
2831    "fqn2" : 0 ,
32+     "fqn3" : 1 ,
2933}
3034
3135TEST_TENSOR_LAYOUT  =  {
3943        dim_sizes = [1 , 1 , 1 ],
4044        dim_order = typing .cast (List [bytes ], [0 , 1 , 2 ]),
4145    ),
46+     "fqn3" : TensorLayout (
47+         scalar_type = ScalarType .INT ,
48+         dim_sizes = [2 , 2 , 2 ],
49+         dim_order = typing .cast (List [bytes ], [0 , 1 ]),
50+     ),
4251}
4352
4453
4554class  TestSerialize (unittest .TestCase ):
55+     def  check_tensor_metadata (
56+         self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata 
57+     ) ->  None :
58+         self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
59+         self .assertEqual (tensor_layout .dim_sizes , tensor_metadata .dim_sizes )
60+         self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
61+ 
4662    def  test_serialize (self ) ->  None :
47-         serializer : DataSerializer  =  FlatTensorSerializer ()
63+         config  =  FlatTensorConfig ()
64+         serializer : DataSerializer  =  FlatTensorSerializer (config )
4865
4966        data  =  bytes (
5067            serializer .serialize_tensors (
@@ -54,14 +71,71 @@ def test_serialize(self) -> None:
5471            )
5572        )
5673
74+         # Check header. 
5775        header  =  FlatTensorHeader .from_bytes (data [0  : FlatTensorHeader .EXPECTED_LENGTH ])
5876        self .assertTrue (header .is_valid ())
5977
6078        self .assertEqual (header .flatbuffer_offset , 48 )
61-         self .assertEqual (header .flatbuffer_size , 200 )
62-         self .assertEqual (header .segment_base_offset , 256 )
63-         self .assertEqual (header .data_size , 16 )
79+         self .assertEqual (header .flatbuffer_size , 288 )
80+         self .assertEqual (header .segment_base_offset , 336 )
81+         self .assertEqual (header .data_size , 48 )
6482
6583        self .assertEqual (
6684            data [header .flatbuffer_offset  +  4  : header .flatbuffer_offset  +  8 ], b"FT01" 
6785        )
86+ 
87+         # Check flat tensor data. 
88+         flat_tensor_bytes  =  data [
89+             header .flatbuffer_offset  : header .flatbuffer_offset  +  header .flatbuffer_size 
90+         ]
91+ 
92+         flat_tensor  =  _convert_to_flat_tensor (flat_tensor_bytes )
93+ 
94+         self .assertEqual (flat_tensor .version , 0 )
95+         self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
96+ 
97+         tensors  =  flat_tensor .tensors 
98+         self .assertEqual (len (tensors ), 3 )
99+         self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
100+         self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn1" ], tensors [0 ])
101+         self .assertEqual (tensors [0 ].segment_index , 0 )
102+         self .assertEqual (tensors [0 ].offset , 0 )
103+ 
104+         self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
105+         self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn2" ], tensors [1 ])
106+         self .assertEqual (tensors [1 ].segment_index , 0 )
107+         self .assertEqual (tensors [1 ].offset , 0 )
108+ 
109+         self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
110+         self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn3" ], tensors [2 ])
111+         self .assertEqual (tensors [2 ].segment_index , 0 )
112+         self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
113+ 
114+         segments  =  flat_tensor .segments 
115+         self .assertEqual (len (segments ), 1 )
116+         self .assertEqual (segments [0 ].offset , 0 )
117+         self .assertEqual (segments [0 ].size , config .tensor_alignment  *  3 )
118+ 
119+         # Check segment data. 
120+         segment_data  =  data [
121+             header .segment_base_offset  : header .segment_base_offset  +  segments [0 ].size 
122+         ]
123+ 
124+         t0_start  =  0 
125+         t0_len  =  len (TEST_TENSOR_BUFFER [0 ])
126+         t0_end  =  config .tensor_alignment 
127+         self .assertEqual (
128+             segment_data [t0_start  : t0_start  +  t0_len ], TEST_TENSOR_BUFFER [0 ]
129+         )
130+         padding  =  b"\x00 "  *  (t0_end  -  t0_len )
131+         self .assertEqual (segment_data [t0_start  +  t0_len  : t0_end ], padding )
132+ 
133+         t1_start  =  config .tensor_alignment 
134+         t1_len  =  len (TEST_TENSOR_BUFFER [1 ])
135+         t1_end  =  config .tensor_alignment  *  3 
136+         self .assertEqual (
137+             segment_data [t1_start  : t1_start  +  t1_len ],
138+             TEST_TENSOR_BUFFER [1 ],
139+         )
140+         padding  =  b"\x00 "  *  (t1_end  -  (t1_len  +  t1_start ))
141+         self .assertEqual (segment_data [t1_start  +  t1_len  : t1_end ], padding )
0 commit comments