1515)
1616
1717from executorch .exir .schema import ScalarType
18+ from executorch .extension .flat_tensor .serialize .flat_tensor_schema import TensorMetadata
1819
1920from executorch .extension .flat_tensor .serialize .serialize import (
21+ _convert_to_flat_tensor ,
22+ FlatTensorConfig ,
2023 FlatTensorHeader ,
2124 FlatTensorSerializer ,
2225)
2326
2427# Test artifacts
25- TEST_TENSOR_BUFFER = [b"tensor" ]
28+ TEST_TENSOR_BUFFER = [b"\x11 " * 4 , b" \x22 " * 32 ]
2629TEST_TENSOR_MAP = {
2730 "fqn1" : 0 ,
2831 "fqn2" : 0 ,
32+ "fqn3" : 1 ,
2933}
3034
3135TEST_TENSOR_LAYOUT = {
3943 sizes = [1 , 1 , 1 ],
4044 dim_order = typing .cast (List [bytes ], [0 , 1 , 2 ]),
4145 ),
46+ "fqn3" : TensorLayout (
47+ scalar_type = ScalarType .INT ,
48+ sizes = [2 , 2 , 2 ],
49+ dim_order = typing .cast (List [bytes ], [0 , 1 ]),
50+ ),
4251}
4352
4453
4554class TestSerialize (unittest .TestCase ):
55+ def check_tensor_metadata (
56+ self , tensor_layout : TensorLayout , tensor_metadata : TensorMetadata
57+ ) -> None :
58+ self .assertEqual (tensor_layout .scalar_type , tensor_metadata .scalar_type )
59+ self .assertEqual (tensor_layout .sizes , tensor_metadata .sizes )
60+ self .assertEqual (tensor_layout .dim_order , tensor_metadata .dim_order )
61+
4662 def test_serialize (self ) -> None :
47- serializer : DataSerializer = FlatTensorSerializer ()
63+ config = FlatTensorConfig ()
64+ serializer : DataSerializer = FlatTensorSerializer (config )
4865
4966 data = bytes (
5067 serializer .serialize_tensors (
@@ -54,14 +71,71 @@ def test_serialize(self) -> None:
5471 )
5572 )
5673
74+ # Check header.
5775 header = FlatTensorHeader .from_bytes (data [0 : FlatTensorHeader .EXPECTED_LENGTH ])
5876 self .assertTrue (header .is_valid ())
5977
6078 self .assertEqual (header .flatbuffer_offset , 48 )
61- self .assertEqual (header .flatbuffer_size , 200 )
62- self .assertEqual (header .segment_base_offset , 256 )
63- self .assertEqual (header .data_size , 16 )
79+ self .assertEqual (header .flatbuffer_size , 288 )
80+ self .assertEqual (header .segment_base_offset , 336 )
81+ self .assertEqual (header .data_size , 48 )
6482
6583 self .assertEqual (
6684 data [header .flatbuffer_offset + 4 : header .flatbuffer_offset + 8 ], b"FT01"
6785 )
86+
87+ # Check flat tensor data.
88+ flat_tensor_bytes = data [
89+ header .flatbuffer_offset : header .flatbuffer_offset + header .flatbuffer_size
90+ ]
91+
92+ flat_tensor = _convert_to_flat_tensor (flat_tensor_bytes )
93+
94+ self .assertEqual (flat_tensor .version , 0 )
95+ self .assertEqual (flat_tensor .tensor_alignment , config .tensor_alignment )
96+
97+ tensors = flat_tensor .tensors
98+ self .assertEqual (len (tensors ), 3 )
99+ self .assertEqual (tensors [0 ].fully_qualified_name , "fqn1" )
100+ self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn1" ], tensors [0 ])
101+ self .assertEqual (tensors [0 ].segment_index , 0 )
102+ self .assertEqual (tensors [0 ].offset , 0 )
103+
104+ self .assertEqual (tensors [1 ].fully_qualified_name , "fqn2" )
105+ self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn2" ], tensors [1 ])
106+ self .assertEqual (tensors [1 ].segment_index , 0 )
107+ self .assertEqual (tensors [1 ].offset , 0 )
108+
109+ self .assertEqual (tensors [2 ].fully_qualified_name , "fqn3" )
110+ self .check_tensor_metadata (TEST_TENSOR_LAYOUT ["fqn3" ], tensors [2 ])
111+ self .assertEqual (tensors [2 ].segment_index , 0 )
112+ self .assertEqual (tensors [2 ].offset , config .tensor_alignment )
113+
114+ segments = flat_tensor .segments
115+ self .assertEqual (len (segments ), 1 )
116+ self .assertEqual (segments [0 ].offset , 0 )
117+ self .assertEqual (segments [0 ].size , config .tensor_alignment * 3 )
118+
119+ # Check segment data.
120+ segment_data = data [
121+ header .segment_base_offset : header .segment_base_offset + segments [0 ].size
122+ ]
123+
124+ t0_start = 0
125+ t0_len = len (TEST_TENSOR_BUFFER [0 ])
126+ t0_end = config .tensor_alignment
127+ self .assertEqual (
128+ segment_data [t0_start : t0_start + t0_len ], TEST_TENSOR_BUFFER [0 ]
129+ )
130+ padding = b"\x00 " * (t0_end - t0_len )
131+ self .assertEqual (segment_data [t0_start + t0_len : t0_end ], padding )
132+
133+ t1_start = config .tensor_alignment
134+ t1_len = len (TEST_TENSOR_BUFFER [1 ])
135+ t1_end = config .tensor_alignment * 3
136+ self .assertEqual (
137+ segment_data [t1_start : t1_start + t1_len ],
138+ TEST_TENSOR_BUFFER [1 ],
139+ )
140+ padding = b"\x00 " * (t1_end - (t1_len + t1_start ))
141+ self .assertEqual (segment_data [t1_start + t1_len : t1_end ], padding )
0 commit comments