@@ -32,29 +32,41 @@ class DataMapTest : public ::testing::Test {
3232 // Since these tests cause ET_LOG to be called, the PAL must be initialized
3333 // first.
3434 executorch::runtime::runtime_init ();
35- }
36- };
3735
38- TEST_F (DataMapTest, LoadDataMap) {
39- const char * path = std::getenv (" ET_MODULE_LINEAR_DATA" );
40- Result<FileDataLoader> loader = FileDataLoader::from (path);
41- ASSERT_EQ (loader.error (), Error::Ok);
36+ // Load data map.
37+ // The eager linear model is defined at:
38+ // //executorch/test/models/linear_model.py
39+ const char * path = std::getenv (" ET_MODULE_LINEAR_DATA" );
40+ Result<FileDataLoader> loader = FileDataLoader::from (path);
41+ ASSERT_EQ (loader.error (), Error::Ok);
4242
43- Result<FreeableBuffer> header = loader->load (
44- /* offset=*/ 0 ,
45- FlatTensorHeader::kNumHeadBytes ,
46- /* segment_info=*/
47- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
43+ Result<FreeableBuffer> header = loader->load (
44+ /* offset=*/ 0 ,
45+ FlatTensorHeader::kNumHeadBytes ,
46+ /* segment_info=*/
47+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
4848
49- ASSERT_EQ (header.error (), Error::Ok);
49+ ASSERT_EQ (header.error (), Error::Ok);
50+
51+ data_map_loader_ =
52+ std::make_unique<FileDataLoader>(std::move (loader.get ()));
53+ }
54+ std::unique_ptr<FileDataLoader> data_map_loader_;
55+ };
5056
51- auto data_map_loader_ =
52- std::make_unique<FileDataLoader>(std::move (loader.get ()));
57+ TEST_F (DataMapTest, LoadDataMap) {
58+ Result<DataMap> data_map = DataMap::load (data_map_loader_.get ());
59+ EXPECT_EQ (data_map.error (), Error::Ok);
60+ }
5361
62+ TEST_F (DataMapTest, DataMap_GetMetadata) {
5463 Result<DataMap> data_map = DataMap::load (data_map_loader_.get ());
5564 EXPECT_EQ (data_map.error (), Error::Ok);
5665
57- // Check tensor metadata.
66+ // Check tensor layouts are correct.
67+ // From //executorch/test/models/linear_model.py, we have the tensors
68+ // self.a = 3 * torch.ones(2, 2, dtype=torch.float)
69+ // self.b = 2 * torch.ones(2, 2, dtype=torch.float)
5870 Result<const TensorLayout> const_a_res = data_map->get_metadata (" a" );
5971 assert (const_a_res.ok ());
6072
@@ -83,16 +95,50 @@ TEST_F(DataMapTest, LoadDataMap) {
8395 EXPECT_EQ (dim_order_b[0 ], 0 );
8496 EXPECT_EQ (dim_order_b[1 ], 1 );
8597
86- // Check tensor data.
98+ // Check get_metadata fails when key is not found.
99+ Result<const TensorLayout> const_c_res = data_map->get_metadata (" c" );
100+ EXPECT_EQ (const_c_res.error (), Error::InvalidArgument);
101+ }
102+
103+ TEST_F (DataMapTest, DataMap_GetData) {
104+ Result<DataMap> data_map = DataMap::load (data_map_loader_.get ());
105+ EXPECT_EQ (data_map.error (), Error::Ok);
106+
107+ // Check tensor data sizes are correct.
87108 Result<FreeableBuffer> data_a_res = data_map->get_data (" a" );
88109 assert (data_a_res.ok ());
89- // Check we have the correct tensor data.
90110 FreeableBuffer data_a = std::move (data_a_res.get ());
91111 EXPECT_EQ (data_a.size (), 16 );
92112
93113 Result<FreeableBuffer> data_b_res = data_map->get_data (" b" );
94114 assert (data_b_res.ok ());
95- // Check we have the correct tensor data.
96115 FreeableBuffer data_b = std::move (data_b_res.get ());
97116 EXPECT_EQ (data_b.size (), 16 );
117+
118+ // Check get_data fails when key is not found.
119+ Result<FreeableBuffer> data_c_res = data_map->get_data (" c" );
120+ EXPECT_EQ (data_c_res.error (), Error::InvalidArgument);
121+ }
122+
123+ TEST_F (DataMapTest, DataMap_Keys) {
124+ Result<DataMap> data_map = DataMap::load (data_map_loader_.get ());
125+ EXPECT_EQ (data_map.error (), Error::Ok);
126+
127+ // Check num tensors is 2.
128+ Result<size_t > num_tensors_res = data_map->get_num_keys ();
129+ assert (num_tensors_res.ok ());
130+ EXPECT_EQ (num_tensors_res.get (), 2 );
131+
132+ // Check get_key returns the correct keys.
133+ Result<const char *> key0_res = data_map->get_key (0 );
134+ assert (key0_res.ok ());
135+ EXPECT_EQ (strcmp (key0_res.get (), " b" ), 0 );
136+
137+ Result<const char *> key1_res = data_map->get_key (1 );
138+ assert (key1_res.ok ());
139+ EXPECT_EQ (strcmp (key1_res.get (), " a" ), 0 );
140+
141+ // Check get_key fails when out of bounds.
142+ Result<const char *> key2_res = data_map->get_key (2 );
143+ EXPECT_EQ (key2_res.error (), Error::InvalidArgument);
98144}
0 commit comments