@@ -28,32 +28,36 @@ using torch::executor::util::FileDataLoader;
2828
2929class FlatTensorDataMapTest : public ::testing::Test {
3030 protected:
31+ void create_loader (const char * path, const char * module_name) {
32+ // Create a loader for the serialized data map.
33+ Result<FileDataLoader> loader = FileDataLoader::from (path);
34+ ASSERT_EQ (loader.error (), Error::Ok);
35+ loaders_.insert (
36+ {module_name,
37+ std::make_unique<FileDataLoader>(std::move (loader.get ()))});
38+ }
3139 void SetUp () override {
3240 // Since these tests cause ET_LOG to be called, the PAL must be initialized
3341 // first.
3442 executorch::runtime::runtime_init ();
3543
36- // Load data map. The eager linear model is defined at:
37- // //executorch/test/models/linear_model.py
38- const char * path = std::getenv (" ET_MODULE_LINEAR_DATA_PATH" );
39- Result<FileDataLoader> loader = FileDataLoader::from (path);
40- ASSERT_EQ (loader.error (), Error::Ok);
41-
42- data_map_loader_ =
43- std::make_unique<FileDataLoader>(std::move (loader.get ()));
44+ // Model defined in //executorch/test/models/linear_model.py
45+ create_loader (std::getenv (" ET_MODULE_LINEAR_DATA_PATH" ), " linear" );
46+ // Model defined in //executorch/test/models/export_delegated_program.py
47+ create_loader (std::getenv (" ET_MODULE_LINEAR_XNN_DATA_PATH" ), " linear_xnn" );
4448 }
45- std::unique_ptr<FileDataLoader> data_map_loader_ ;
49+ std::unordered_map<std::string, std:: unique_ptr<FileDataLoader>> loaders_ ;
4650};
4751
4852TEST_F (FlatTensorDataMapTest, LoadFlatTensorDataMap) {
4953 Result<FlatTensorDataMap> data_map =
50- FlatTensorDataMap::load (data_map_loader_ .get ());
54+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
5155 EXPECT_EQ (data_map.error (), Error::Ok);
5256}
5357
5458TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
5559 Result<FlatTensorDataMap> data_map =
56- FlatTensorDataMap::load (data_map_loader_ .get ());
60+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
5761 EXPECT_EQ (data_map.error (), Error::Ok);
5862
5963 // Check tensor layouts are correct.
@@ -95,7 +99,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
9599
96100TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
97101 Result<FlatTensorDataMap> data_map =
98- FlatTensorDataMap::load (data_map_loader_ .get ());
102+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
99103 EXPECT_EQ (data_map.error (), Error::Ok);
100104
101105 // Check tensor data sizes are correct.
@@ -116,7 +120,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
116120
117121TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
118122 Result<FlatTensorDataMap> data_map =
119- FlatTensorDataMap::load (data_map_loader_ .get ());
123+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
120124 EXPECT_EQ (data_map.error (), Error::Ok);
121125
122126 // Check num tensors is 2.
@@ -140,7 +144,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
140144
141145TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
142146 Result<FlatTensorDataMap> data_map =
143- FlatTensorDataMap::load (data_map_loader_ .get ());
147+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
144148 EXPECT_EQ (data_map.error (), Error::Ok);
145149
146150 // get the metadata
@@ -160,3 +164,62 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
160164 }
161165 free (data);
162166}
167+
168+ TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_GetData_Xnnpack) {
169+ Result<FlatTensorDataMap> data_map =
170+ FlatTensorDataMap::load (loaders_[" linear_xnn" ].get ());
171+ EXPECT_EQ (data_map.error (), Error::Ok);
172+
173+ // Check tensor data sizes are correct.
174+ // 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885 is the
175+ // hash of the 3*3 identity matrix
176+ Result<FreeableBuffer> data_weight_res = data_map->get_data (
177+ " 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885" );
178+ ASSERT_EQ (Error::Ok, data_weight_res.error ());
179+ FreeableBuffer data_a = std::move (data_weight_res.get ());
180+ EXPECT_EQ (data_a.size (), 36 ); // 3*3*4 (3*3 matrix, 4 bytes per float)
181+
182+ // 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b is the
183+ // hash of the 3*1 vector [1, 1, 1]
184+ Result<FreeableBuffer> data_bias_res = data_map->get_data (
185+ " 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b" );
186+ ASSERT_EQ (Error::Ok, data_bias_res.error ());
187+ FreeableBuffer data_b = std::move (data_bias_res.get ());
188+ EXPECT_EQ (data_b.size (), 12 ); // 3*4 (3*1 vector, 4 bytes per float)
189+
190+ // Check get_data fails when key is not found.
191+ Result<FreeableBuffer> data_c_res = data_map->get_data (" c" );
192+ EXPECT_EQ (data_c_res.error (), Error::NotFound);
193+ }
194+
195+ TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_Keys_Xnnpack) {
196+ Result<FlatTensorDataMap> data_map =
197+ FlatTensorDataMap::load (loaders_[" linear_xnn" ].get ());
198+ EXPECT_EQ (data_map.error (), Error::Ok);
199+
200+ // Check num tensors is 2.
201+ Result<size_t > num_tensors_res = data_map->get_num_keys ();
202+ ASSERT_EQ (Error::Ok, num_tensors_res.error ());
203+ EXPECT_EQ (num_tensors_res.get (), 2 );
204+
205+ // Check get_key returns the correct keys.
206+ Result<const char *> key0_res = data_map->get_key (0 );
207+ ASSERT_EQ (Error::Ok, key0_res.error ());
208+ EXPECT_EQ (
209+ strcmp (
210+ key0_res.get (),
211+ " 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885" ),
212+ 0 );
213+
214+ Result<const char *> key1_res = data_map->get_key (1 );
215+ ASSERT_EQ (Error::Ok, key1_res.error ());
216+ EXPECT_EQ (
217+ strcmp (
218+ key1_res.get (),
219+ " 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b" ),
220+ 0 );
221+
222+ // Check get_key fails when out of bounds.
223+ Result<const char *> key2_res = data_map->get_key (2 );
224+ EXPECT_EQ (key2_res.error (), Error::InvalidArgument);
225+ }
0 commit comments