Skip to content

Commit 54c2137

Browse files
committed
Update on "[executorch][flat_tensor] DataMap implementation"
DataMap implementation that:
* Loads a flat_tensor file.
* Populates a map with {fqn: tensor} and {fqn: TensorLayout}.
* Makes tensor information available via the named_data_map.h interface.

For now, DataMap doesn't store the DataLoader.
- If/when tensors are in their own segments, DataMap should also store a DataLoader.

Differential Revision: [D67064580](https://our.internmc.facebook.com/intern/diff/D67064580/) [ghstack-poisoned]
2 parents eb49548 + c74a135 commit 54c2137

File tree

4 files changed

+86
-36
lines changed

4 files changed

+86
-36
lines changed

extension/flat_tensor/named_data_map/data_map.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,16 @@ ET_NODISCARD Result<const TensorLayout> DataMap::get_metadata(
4949
const char* fqn) const {
5050
auto result = _name_to_tensor.find(fqn);
5151
if (result == _name_to_tensor.end()) {
52-
return Error::NotFound;
52+
return Error::InvalidArgument;
5353
}
54+
// value is a tuple of (segment_index, offset, tensor_layout)
5455
return std::get<2>(result->second);
5556
}
5657

5758
ET_NODISCARD Result<FreeableBuffer> DataMap::get_data(const char* fqn) const {
5859
auto result = _name_to_tensor.find(fqn);
5960
if (result == _name_to_tensor.end()) {
60-
return Error::NotFound;
61+
return Error::InvalidArgument;
6162
}
6263
int offset = std::get<1>(result->second);
6364
TensorLayout tensor = std::get<2>(result->second);
@@ -66,17 +67,17 @@ ET_NODISCARD Result<FreeableBuffer> DataMap::get_data(const char* fqn) const {
6667
return FreeableBuffer(data, tensor.nbytes(), nullptr);
6768
}
6869

69-
ET_NODISCARD Error
70+
ET_NODISCARD Result<size_t>
7071
DataMap::load_data_into(const char* fqn, size_t size, void* buffer) const {
7172
return Error::NotImplemented;
7273
}
7374

74-
ET_NODISCARD Result<int> DataMap::get_num_keys() const {
75+
ET_NODISCARD Result<size_t> DataMap::get_num_keys() const {
7576
return _name_to_tensor.size();
7677
}
7778

78-
ET_NODISCARD Result<const char*> DataMap::get_key(int index) const {
79-
if (index <= 0 || index >= _name_to_tensor.size()) {
79+
ET_NODISCARD Result<const char*> DataMap::get_key(size_t index) const {
80+
if (index < 0 || index >= _name_to_tensor.size()) {
8081
return Error::InvalidArgument;
8182
}
8283

@@ -158,7 +159,7 @@ ET_NODISCARD Result<const char*> DataMap::get_key(int index) const {
158159
assert(s_tensor_metadata != nullptr);
159160

160161
std::unordered_map<std::string, std::tuple<int, int, TensorLayout>>
161-
fqn_to_tensor_layout = {};
162+
name_to_tensor = {};
162163
for (int i = 0; i < s_tensor_metadata->size(); i++) {
163164
// Create TensorLayouts.
164165
ScalarType scalar_type =
@@ -178,7 +179,7 @@ ET_NODISCARD Result<const char*> DataMap::get_key(int index) const {
178179
std::string fqn = s_tensor_metadata->Get(i)->fully_qualified_name()->str();
179180

180181
auto val = std::make_tuple(segment_index, offset, tensor_layout);
181-
fqn_to_tensor_layout.insert({fqn, std::move(val)});
182+
name_to_tensor.insert({fqn, std::move(val)});
182183
}
183184

184185
// Load constant data.
@@ -203,7 +204,7 @@ ET_NODISCARD Result<const char*> DataMap::get_key(int index) const {
203204

204205
return DataMap(
205206
std::move(flat_tensor_data.get()),
206-
std::move(fqn_to_tensor_layout),
207+
std::move(name_to_tensor),
207208
std::move(_data_ro.get()));
208209
}
209210

extension/flat_tensor/named_data_map/data_map.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,13 @@ class DataMap final : public executorch::runtime::NamedDataMap {
3838
ET_NODISCARD
3939
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
4040
const char* fqn) const override;
41-
ET_NODISCARD runtime::Error
41+
ET_NODISCARD executorch::runtime::Result<size_t>
4242
load_data_into(const char* fqn, size_t size, void* buffer) const override;
4343

44-
ET_NODISCARD executorch::runtime::Result<int> get_num_keys() const override;
44+
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
45+
const override;
4546
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
46-
int index) const override;
47+
size_t index) const override;
4748

4849
DataMap(DataMap&&) noexcept = default;
4950
~DataMap() override;

extension/flat_tensor/test/data_map_test.cpp

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,41 @@ class DataMapTest : public ::testing::Test {
3232
// Since these tests cause ET_LOG to be called, the PAL must be initialized
3333
// first.
3434
executorch::runtime::runtime_init();
35-
}
36-
};
3735

38-
TEST_F(DataMapTest, LoadDataMap) {
39-
const char* path = std::getenv("ET_MODULE_LINEAR_DATA");
40-
Result<FileDataLoader> loader = FileDataLoader::from(path);
41-
ASSERT_EQ(loader.error(), Error::Ok);
36+
// Load data map.
37+
// The eager linear model is defined at:
38+
// //executorch/test/models/linear_model.py
39+
const char* path = std::getenv("ET_MODULE_LINEAR_DATA");
40+
Result<FileDataLoader> loader = FileDataLoader::from(path);
41+
ASSERT_EQ(loader.error(), Error::Ok);
4242

43-
Result<FreeableBuffer> header = loader->load(
44-
/*offset=*/0,
45-
FlatTensorHeader::kNumHeadBytes,
46-
/*segment_info=*/
47-
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
43+
Result<FreeableBuffer> header = loader->load(
44+
/*offset=*/0,
45+
FlatTensorHeader::kNumHeadBytes,
46+
/*segment_info=*/
47+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
4848

49-
ASSERT_EQ(header.error(), Error::Ok);
49+
ASSERT_EQ(header.error(), Error::Ok);
50+
51+
data_map_loader_ =
52+
std::make_unique<FileDataLoader>(std::move(loader.get()));
53+
}
54+
std::unique_ptr<FileDataLoader> data_map_loader_;
55+
};
5056

51-
auto data_map_loader_ =
52-
std::make_unique<FileDataLoader>(std::move(loader.get()));
57+
TEST_F(DataMapTest, LoadDataMap) {
58+
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
59+
EXPECT_EQ(data_map.error(), Error::Ok);
60+
}
5361

62+
TEST_F(DataMapTest, DataMap_GetMetadata) {
5463
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
5564
EXPECT_EQ(data_map.error(), Error::Ok);
5665

57-
// Check tensor metadata.
66+
// Check tensor layouts are correct.
67+
// From //executorch/test/models/linear_model.py, we have the tensors
68+
// self.a = 3 * torch.ones(2, 2, dtype=torch.float)
69+
// self.b = 2 * torch.ones(2, 2, dtype=torch.float)
5870
Result<const TensorLayout> const_a_res = data_map->get_metadata("a");
5971
assert(const_a_res.ok());
6072

@@ -83,16 +95,50 @@ TEST_F(DataMapTest, LoadDataMap) {
8395
EXPECT_EQ(dim_order_b[0], 0);
8496
EXPECT_EQ(dim_order_b[1], 1);
8597

86-
// Check tensor data.
98+
// Check get_metadata fails when key is not found.
99+
Result<const TensorLayout> const_c_res = data_map->get_metadata("c");
100+
EXPECT_EQ(const_c_res.error(), Error::InvalidArgument);
101+
}
102+
103+
TEST_F(DataMapTest, DataMap_GetData) {
104+
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
105+
EXPECT_EQ(data_map.error(), Error::Ok);
106+
107+
// Check tensor data sizes are correct.
87108
Result<FreeableBuffer> data_a_res = data_map->get_data("a");
88109
assert(data_a_res.ok());
89-
// Check we have the correct tensor data.
90110
FreeableBuffer data_a = std::move(data_a_res.get());
91111
EXPECT_EQ(data_a.size(), 16);
92112

93113
Result<FreeableBuffer> data_b_res = data_map->get_data("b");
94114
assert(data_b_res.ok());
95-
// Check we have the correct tensor data.
96115
FreeableBuffer data_b = std::move(data_b_res.get());
97116
EXPECT_EQ(data_b.size(), 16);
117+
118+
// Check get_data fails when key is not found.
119+
Result<FreeableBuffer> data_c_res = data_map->get_data("c");
120+
EXPECT_EQ(data_c_res.error(), Error::InvalidArgument);
121+
}
122+
123+
TEST_F(DataMapTest, DataMap_Keys) {
124+
Result<DataMap> data_map = DataMap::load(data_map_loader_.get());
125+
EXPECT_EQ(data_map.error(), Error::Ok);
126+
127+
// Check num tensors is 2.
128+
Result<size_t> num_tensors_res = data_map->get_num_keys();
129+
assert(num_tensors_res.ok());
130+
EXPECT_EQ(num_tensors_res.get(), 2);
131+
132+
// Check get_key returns the correct keys.
133+
Result<const char*> key0_res = data_map->get_key(0);
134+
assert(key0_res.ok());
135+
EXPECT_EQ(strcmp(key0_res.get(), "b"), 0);
136+
137+
Result<const char*> key1_res = data_map->get_key(1);
138+
assert(key1_res.ok());
139+
EXPECT_EQ(strcmp(key1_res.get(), "a"), 0);
140+
141+
// Check get_key fails when out of bounds.
142+
Result<const char*> key2_res = data_map->get_key(2);
143+
EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
98144
}

runtime/core/named_data_map.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,28 +46,30 @@ class ET_EXPERIMENTAL NamedDataMap {
4646
* Loads data corresponding to the fqn into the provided buffer.
4747
*
4848
* @param fqn Fully qualified name of the tensor.
49-
* @param size The number of bytes to load.
49+
* @param size The number of bytes to load. Use `get_metadata` to retrieve the
50+
* size of the tensor for a given fqn.
5051
* @param buffer The buffer to load the data into. Must point to at least
5152
* `size` bytes of memory.
52-
* @return An error code on if the load was successful.
53+
* @return Result containing the number of bytes written on success.
5354
*/
54-
ET_NODISCARD virtual Error
55+
ET_NODISCARD virtual Result<size_t>
5556
load_data_into(const char* fqn, size_t size, void* buffer) const = 0;
5657

5758
/**
5859
* Get the number of keys in the NamedDataMap.
5960
*
6061
* @return Result containing the number of keys.
6162
*/
62-
ET_NODISCARD virtual Result<int> get_num_keys() const = 0;
63+
ET_NODISCARD virtual Result<size_t> get_num_keys() const = 0;
6364

6465
/**
6566
* Get the key at the given index.
6667
*
6768
* @param index The index of the key to retrieve.
68-
* @return Result containing the key at the given index.
69+
* @return Result containing the key at the given index. Note: the returned
70+
* pointer is only valid for the lifetime of the NamedDataMap.
6971
*/
70-
ET_NODISCARD virtual Result<const char*> get_key(int index) const = 0;
72+
ET_NODISCARD virtual Result<const char*> get_key(size_t index) const = 0;
7173
};
7274

7375
} // namespace runtime

0 commit comments

Comments
 (0)