Skip to content

Commit a0beed4

Browse files
committed
Add named data map merge
Add merge functionality to named_data_map interface and flat_tensor_data_map. pte_data_map does not implement merge. Differential Revision: [D76351013](https://our.internmc.facebook.com/intern/diff/D76351013/) [ghstack-poisoned]
1 parent cbd3874 commit a0beed4

File tree

6 files changed

+248
-87
lines changed

6 files changed

+248
-87
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 130 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
#include <executorch/runtime/core/span.h>
2020
#include <executorch/runtime/platform/compiler.h>
2121

22+
using executorch::aten::ScalarType;
23+
using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap;
24+
using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
25+
using executorch::runtime::DataLoader;
2226
using executorch::runtime::Error;
2327
using executorch::runtime::FreeableBuffer;
2428
using executorch::runtime::Result;
2529
using executorch::runtime::Span;
2630

27-
using executorch::aten::ScalarType;
28-
using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
29-
using executorch::runtime::DataLoader;
30-
3131
namespace executorch {
3232
namespace extension {
3333

@@ -103,82 +103,109 @@ Result<const TensorLayout> create_tensor_layout(
103103

104104
// Returns the TensorLayout stored under `key`.
//
// `key_to_map_index_` records which map owns each key: -1 means the key lives
// in this object's own flat_tensor flatbuffer; any other value is an index
// into `merged_maps_`, and the request is forwarded to that map.
//
// @param key Name of the tensor to look up.
// @return The layout on success; Error::NotFound if `key` is not registered;
//     otherwise the error reported by the owning lookup.
ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_tensor_layout(
    executorch::aten::string_view key) const {
  // A string_view is not guaranteed to be NUL-terminated, so build the lookup
  // key from (data, size) rather than from data() alone. One find() replaces
  // the previous find() + at() pair.
  const auto entry =
      key_to_map_index_.find(std::string(key.data(), key.size()));
  if (entry == key_to_map_index_.end()) {
    return Error::NotFound;
  }

  const int32_t map_index = entry->second;
  if (map_index != -1) {
    // Key is owned by a merged map; delegate.
    return merged_maps_[static_cast<size_t>(map_index)]->get_tensor_layout(key);
  }

  // Key is owned by the flat_tensor file backing this map.
  Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
      key,
      flat_tensor_->named_data(),
      flat_tensor_->segments(),
      header_.segment_base_offset + header_.segment_data_size);
  if (!named_data.ok()) {
    return named_data.error();
  }
  return create_tensor_layout(named_data.get()->tensor_layout());
}
116125

117126
// Loads the data stored under `key` and returns it as a FreeableBuffer.
//
// `key_to_map_index_` records which map owns each key: -1 means the key lives
// in this object's own flat_tensor flatbuffer; any other value is an index
// into `merged_maps_`, and the request is forwarded to that map.
//
// @param key Name of the data to load.
// @return The loaded buffer on success; Error::NotFound if `key` is not
//     registered; otherwise the error reported by the owning lookup or by the
//     data loader.
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
    executorch::aten::string_view key) const {
  // A string_view is not guaranteed to be NUL-terminated, so build the lookup
  // key from (data, size) rather than from data() alone. One find() replaces
  // the previous find() + at() pair.
  const auto entry =
      key_to_map_index_.find(std::string(key.data(), key.size()));
  if (entry == key_to_map_index_.end()) {
    return Error::NotFound;
  }

  const int32_t map_index = entry->second;
  if (map_index != -1) {
    // Key is owned by a merged map; delegate.
    return merged_maps_[static_cast<size_t>(map_index)]->get_data(key);
  }

  // Key is owned by the flat_tensor file backing this map.
  Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
      key,
      flat_tensor_->named_data(),
      flat_tensor_->segments(),
      header_.segment_base_offset + header_.segment_data_size);
  if (!named_data.ok()) {
    return named_data.error();
  }

  // Fetch the segment entry once and read both fields from it.
  const uint32_t segment_index = named_data.get()->segment_index();
  const auto* segment = flat_tensor_->segments()->Get(segment_index);
  const uint64_t segment_offset = segment->offset();
  const uint64_t segment_size = segment->size();

  // Segment offsets are relative to the start of the segment region.
  return loader_->load(
      /*offset=*/header_.segment_base_offset + segment_offset,
      segment_size,
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}
138156

139157
// Loads the data stored under `key` into the caller-provided `buffer`.
//
// `key_to_map_index_` records which map owns each key: -1 means the key lives
// in this object's own flat_tensor flatbuffer; any other value is an index
// into `merged_maps_`, and the request is forwarded to that map.
//
// @param key Name of the data to load.
// @param buffer Destination memory.
// @param size Size of `buffer` in bytes; must be at least the tensor's
//     nbytes(), because exactly nbytes() bytes are written.
// @return Error::Ok on success; Error::NotFound if `key` is not registered;
//     Error::InvalidArgument if `size` is too small; otherwise the error
//     reported by the owning lookup or by the data loader.
ET_NODISCARD Error FlatTensorDataMap::load_data_into(
    executorch::aten::string_view key,
    void* buffer,
    size_t size) const {
  // A string_view is not guaranteed to be NUL-terminated, so build the lookup
  // key from (data, size) rather than from data() alone. One find() replaces
  // the previous find() + at() pair.
  const auto entry =
      key_to_map_index_.find(std::string(key.data(), key.size()));
  if (entry == key_to_map_index_.end()) {
    return Error::NotFound;
  }

  const int32_t map_index = entry->second;
  if (map_index != -1) {
    // Key is owned by a merged map; delegate.
    return merged_maps_[static_cast<size_t>(map_index)]->load_data_into(
        key, buffer, size);
  }

  // Key is owned by the flat_tensor file backing this map.
  Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
      key,
      flat_tensor_->named_data(),
      flat_tensor_->segments(),
      header_.segment_base_offset + header_.segment_data_size);
  if (!named_data.ok()) {
    return named_data.error();
  }

  const uint32_t segment_index = named_data.get()->segment_index();
  const uint64_t segment_offset =
      flat_tensor_->segments()->Get(segment_index)->offset();

  Result<const TensorLayout> tensor_layout =
      create_tensor_layout(named_data.get()->tensor_layout());
  if (!tensor_layout.ok()) {
    return tensor_layout.error();
  }

  // load_into() below writes nbytes() bytes into `buffer`, so the buffer must
  // be at least that large. The previous condition (`size <= nbytes`) was the
  // inverse of what the error message describes and let undersized buffers
  // through.
  ET_CHECK_OR_RETURN_ERROR(
      size >= tensor_layout.get().nbytes(),
      InvalidArgument,
      "Buffer size %zu is smaller than tensor size %zu",
      size,
      tensor_layout.get().nbytes());

  // Load mutable data.
  DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);
  return loader_->load_into(
      header_.segment_base_offset + segment_offset,
      tensor_layout.get().nbytes(),
      info,
      buffer);
}
179206

180207
// Reports the number of keys registered in `key_to_map_index_`, i.e. every
// key that a lookup on this map can resolve.
ET_NODISCARD Result<uint32_t> FlatTensorDataMap::get_num_keys() const {
  // The container reports size_t; the interface returns a 32-bit count.
  return static_cast<uint32_t>(key_to_map_index_.size());
}
183210

184211
ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
@@ -190,7 +217,40 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
190217
"Index %u out of range of size %u",
191218
index,
192219
num_keys);
193-
return flat_tensor_->named_data()->Get(index)->key()->c_str();
220+
221+
uint32_t current_index = 0;
222+
for (const auto& pair : key_to_map_index_) {
223+
if (current_index == index) {
224+
return pair.first.c_str();
225+
}
226+
current_index++;
227+
}
228+
return Error::NotFound;
229+
}
230+
231+
// Merges the keys of `other` into this map.
//
// No data is copied: every key of `other` is registered in
// `key_to_map_index_` with the position `other` will occupy in
// `merged_maps_`, and later lookups for those keys are forwarded to `other`.
// The merge is all-or-nothing; on any error this map is left unchanged.
//
// @param other The map to merge; stored by pointer, so it must stay valid for
//     as long as this map is used.
// @return Error::Ok on success; Error::InvalidArgument if `other` is nullptr
//     or shares a key with this map; otherwise any error returned by
//     `other->get_num_keys()` or `other->get_key()`.
ET_NODISCARD Error FlatTensorDataMap::merge(const NamedDataMap* other) {
  ET_CHECK_OR_RETURN_ERROR(
      other != nullptr, InvalidArgument, "Merge error: other is nullptr.");

  // Propagate a failure instead of calling get() on an error Result.
  Result<uint32_t> num_keys = other->get_num_keys();
  if (!num_keys.ok()) {
    return num_keys.error();
  }

  // Pass 1: read and validate every incoming key before touching any state,
  // so a failure part-way through cannot leave a half-merged map.
  std::vector<std::string> incoming_keys;
  incoming_keys.reserve(num_keys.get());
  for (uint32_t i = 0; i < num_keys.get(); ++i) {
    Result<const char*> key = other->get_key(i);
    if (!key.ok()) {
      return key.error();
    }
    std::string key_str(key.get());
    ET_CHECK_OR_RETURN_ERROR(
        key_to_map_index_.find(key_str) == key_to_map_index_.end(),
        InvalidArgument,
        "Merge error: key %s already exists in the named_data_map.",
        key_str.c_str());
    incoming_keys.push_back(std::move(key_str));
  }

  // Pass 2: register the keys under the slot `other` is about to occupy.
  const auto map_index = static_cast<int32_t>(merged_maps_.size());
  for (std::string& key_str : incoming_keys) {
    key_to_map_index_[std::move(key_str)] = map_index;
  }

  merged_maps_.push_back(other);
  return Error::Ok;
}
195255

196256
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
@@ -261,8 +321,18 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
261321
InvalidExternalData,
262322
"FlatTensor segments is nullptr, malformed PTD file.");
263323

324+
// Add keys to the map.
325+
std::unordered_map<std::string, int32_t> key_to_map_index;
326+
for (int i = 0; i < flat_tensor->named_data()->size(); i++) {
327+
const auto* named_data = flat_tensor->named_data()->Get(i);
328+
key_to_map_index[named_data->key()->c_str()] = -1;
329+
}
264330
return FlatTensorDataMap(
265-
fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader);
331+
fh.get(),
332+
std::move(flat_tensor_data.get()),
333+
flat_tensor,
334+
loader,
335+
std::move(key_to_map_index));
266336
}
267337

268338
} // namespace extension

extension/flat_tensor/flat_tensor_data_map.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,17 @@ class FlatTensorDataMap final
9494
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
9595
uint32_t index) const override;
9696

97+
/**
98+
* Merge a named_data_map into the current one.
99+
* @param[in] other The named_data_map to merge.
100+
* @return Error indicating if the merge was successful or not.
101+
*
102+
* Note: The FlatTensorDataMap does not perform a deep copy; it holds a
103+
* reference to other, so other must outlive the FlatTensorDataMap instance.
104+
*/
105+
ET_NODISCARD executorch::runtime::Error merge(
106+
const NamedDataMap* other) override;
107+
97108
FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default;
98109

99110
~FlatTensorDataMap() override = default;
@@ -103,11 +114,14 @@ class FlatTensorDataMap final
103114
const FlatTensorHeader& header,
104115
executorch::runtime::FreeableBuffer&& flat_tensor_data,
105116
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
106-
executorch::runtime::DataLoader* loader)
117+
executorch::runtime::DataLoader* loader,
118+
std::unordered_map<std::string, int32_t> key_to_map_index)
107119
: header_(header),
108120
flat_tensor_data_(std::move(flat_tensor_data)),
109121
flat_tensor_(flat_tensor),
110-
loader_(loader) {}
122+
loader_(loader),
123+
key_to_map_index_(std::move(key_to_map_index)),
124+
merged_maps_({}) {}
111125

112126
// Not copyable or assignable.
113127
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
@@ -125,6 +139,13 @@ class FlatTensorDataMap final
125139

126140
// Data loader, used to load segment data.
127141
executorch::runtime::DataLoader* loader_;
142+
143+
// Cache of keys to data map index.
144+
// index=-1 is used for the flat_tensor data map.
145+
std::unordered_map<std::string, int32_t> key_to_map_index_;
146+
147+
// Other NamedDataMaps.
148+
std::vector<const NamedDataMap*> merged_maps_;
128149
};
129150

130151
} // namespace extension

0 commit comments

Comments
 (0)