1919#include < executorch/runtime/core/span.h>
2020#include < executorch/runtime/platform/compiler.h>
2121
22+ using executorch::aten::ScalarType;
23+ using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap;
24+ using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
25+ using executorch::runtime::DataLoader;
2226using executorch::runtime::Error;
2327using executorch::runtime::FreeableBuffer;
2428using executorch::runtime::Result;
2529using executorch::runtime::Span;
2630
27- using executorch::aten::ScalarType;
28- using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
29- using executorch::runtime::DataLoader;
30-
3131namespace executorch {
3232namespace extension {
3333
@@ -103,82 +103,109 @@ Result<const TensorLayout> create_tensor_layout(
103103
104104ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_tensor_layout (
105105 executorch::aten::string_view key) const {
106- Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data (
107- key,
108- flat_tensor_->named_data (),
109- flat_tensor_->segments (),
110- header_.segment_base_offset + header_.segment_data_size );
111- if (!named_data.ok ()) {
106+ if (key_to_map_index_.find (key.data ()) == key_to_map_index_.end ()) {
107+ return Error::NotFound;
108+ }
109+ auto index = key_to_map_index_.at (key.data ());
110+ if (index == -1 ) {
111+ Result<const flat_tensor_flatbuffer::NamedData*> named_data =
112+ get_named_data (
113+ key,
114+ flat_tensor_->named_data (),
115+ flat_tensor_->segments (),
116+ header_.segment_base_offset + header_.segment_data_size );
117+ if (named_data.ok ()) {
118+ return create_tensor_layout (named_data.get ()->tensor_layout ());
119+ }
112120 return named_data.error ();
121+ } else {
122+ return merged_maps_[index]->get_tensor_layout (key);
113123 }
114- return create_tensor_layout (named_data.get ()->tensor_layout ());
115124}
116125
117126ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
118127 executorch::aten::string_view key) const {
119- Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data (
120- key,
121- flat_tensor_->named_data (),
122- flat_tensor_->segments (),
123- header_.segment_base_offset + header_.segment_data_size );
124- if (!named_data.ok ()) {
128+ if (key_to_map_index_.find (key.data ()) == key_to_map_index_.end ()) {
129+ return Error::NotFound;
130+ }
131+ auto index = key_to_map_index_.at (key.data ());
132+ if (index == -1 ) {
133+ Result<const flat_tensor_flatbuffer::NamedData*> named_data =
134+ get_named_data (
135+ key,
136+ flat_tensor_->named_data (),
137+ flat_tensor_->segments (),
138+ header_.segment_base_offset + header_.segment_data_size );
139+ if (named_data.ok ()) {
140+ uint32_t segment_index = named_data.get ()->segment_index ();
141+ uint64_t segment_offset =
142+ flat_tensor_->segments ()->Get (segment_index)->offset ();
143+ uint64_t segment_size =
144+ flat_tensor_->segments ()->Get (segment_index)->size ();
145+
146+ return loader_->load (
147+ /* offset=*/ header_.segment_base_offset + segment_offset,
148+ segment_size,
149+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
150+ }
125151 return named_data.error ();
152+ } else {
153+ return merged_maps_[index]->get_data (key);
126154 }
127-
128- uint32_t segment_index = named_data.get ()->segment_index ();
129- uint64_t segment_offset =
130- flat_tensor_->segments ()->Get (segment_index)->offset ();
131- uint64_t segment_size = flat_tensor_->segments ()->Get (segment_index)->size ();
132-
133- return loader_->load (
134- /* offset=*/ header_.segment_base_offset + segment_offset,
135- segment_size,
136- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
137155}
138156
139157ET_NODISCARD Error FlatTensorDataMap::load_data_into (
140158 ET_UNUSED executorch::aten::string_view key,
141159 ET_UNUSED void * buffer,
142160 ET_UNUSED size_t size) const {
143- Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data (
144- key,
145- flat_tensor_->named_data (),
146- flat_tensor_->segments (),
147- header_.segment_base_offset + header_.segment_data_size );
148- if (!named_data.ok ()) {
149- return named_data.error ();
161+ if (key_to_map_index_.find (key.data ()) == key_to_map_index_.end ()) {
162+ return Error::NotFound;
150163 }
164+ auto index = key_to_map_index_.at (key.data ());
165+ if (index == -1 ) {
166+ Result<const flat_tensor_flatbuffer::NamedData*> named_data =
167+ get_named_data (
168+ key,
169+ flat_tensor_->named_data (),
170+ flat_tensor_->segments (),
171+ header_.segment_base_offset + header_.segment_data_size );
172+ if (!named_data.ok ()) {
173+ return named_data.error ();
174+ }
151175
152- uint32_t segment_index = named_data.get ()->segment_index ();
153- uint64_t segment_offset =
154- flat_tensor_->segments ()->Get (segment_index)->offset ();
176+ uint32_t segment_index = named_data.get ()->segment_index ();
177+ uint64_t segment_offset =
178+ flat_tensor_->segments ()->Get (segment_index)->offset ();
155179
156- Result<const TensorLayout> tensor_layout =
157- create_tensor_layout (named_data.get ()->tensor_layout ());
180+ Result<const TensorLayout> tensor_layout =
181+ create_tensor_layout (named_data.get ()->tensor_layout ());
158182
159- if (!tensor_layout.ok ()) {
160- return tensor_layout.error ();
161- }
183+ if (!tensor_layout.ok ()) {
184+ return tensor_layout.error ();
185+ }
162186
163- ET_CHECK_OR_RETURN_ERROR (
164- size <= tensor_layout.get ().nbytes (),
165- InvalidArgument,
166- " Buffer size %zu is smaller than tensor size %zu" ,
167- size,
168- tensor_layout.get ().nbytes ());
169-
170- // Load mutable data.
171- DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
172- DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
173- return loader_->load_into (
174- header_.segment_base_offset + segment_offset,
175- tensor_layout.get ().nbytes (),
176- info,
177- buffer);
187+ ET_CHECK_OR_RETURN_ERROR (
188+ size <= tensor_layout.get ().nbytes (),
189+ InvalidArgument,
190+ " Buffer size %zu is smaller than tensor size %zu" ,
191+ size,
192+ tensor_layout.get ().nbytes ());
193+
194+ // Load mutable data.
195+ DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
196+ DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
197+ return loader_->load_into (
198+ header_.segment_base_offset + segment_offset,
199+ tensor_layout.get ().nbytes (),
200+ info,
201+ buffer);
202+ } else {
203+ return merged_maps_[index]->load_data_into (key, buffer, size);
204+ }
178205}
179206
180207ET_NODISCARD Result<uint32_t > FlatTensorDataMap::get_num_keys () const {
181- return flat_tensor_-> named_data ()-> size ();
208+ return key_to_map_index_. size ();
182209}
183210
184211ET_NODISCARD Result<const char *> FlatTensorDataMap::get_key (
@@ -190,7 +217,40 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
190217 " Index %u out of range of size %u" ,
191218 index,
192219 num_keys);
193- return flat_tensor_->named_data ()->Get (index)->key ()->c_str ();
220+
221+ uint32_t current_index = 0 ;
222+ for (const auto & pair : key_to_map_index_) {
223+ if (current_index == index) {
224+ return pair.first .c_str ();
225+ }
226+ current_index++;
227+ }
228+ return Error::NotFound;
229+ }
230+
231+ ET_NODISCARD Error FlatTensorDataMap::merge (const NamedDataMap* other) {
232+ ET_CHECK_OR_RETURN_ERROR (
233+ other != nullptr , InvalidArgument, " Merge error: other is nullptr." );
234+
235+ // Check if any duplicate keys exist.
236+ uint32_t num_keys = other->get_num_keys ().get ();
237+
238+ for (uint32_t i = 0 ; i < num_keys; i++) {
239+ const char * key = other->get_key (i).get ();
240+ ET_CHECK_OR_RETURN_ERROR (
241+ key_to_map_index_.find (key) == key_to_map_index_.end (),
242+ InvalidArgument,
243+ " Merge error: key %s already exists in the named_data_map." ,
244+ key);
245+ }
246+ // Place keys into the map.
247+ for (uint32_t i = 0 ; i < num_keys; i++) {
248+ const char * key = other->get_key (i).get ();
249+ key_to_map_index_[key] = merged_maps_.size ();
250+ }
251+
252+ merged_maps_.push_back (other);
253+ return Error::Ok;
194254}
195255
196256/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load (
@@ -261,8 +321,18 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
261321 InvalidExternalData,
262322 " FlatTensor segments is nullptr, malformed PTD file." );
263323
324+ // Add keys to the map.
325+ std::unordered_map<std::string, int32_t > key_to_map_index;
326+ for (int i = 0 ; i < flat_tensor->named_data ()->size (); i++) {
327+ const auto * named_data = flat_tensor->named_data ()->Get (i);
328+ key_to_map_index[named_data->key ()->c_str ()] = -1 ;
329+ }
264330 return FlatTensorDataMap (
265- fh.get (), std::move (flat_tensor_data.get ()), flat_tensor, loader);
331+ fh.get (),
332+ std::move (flat_tensor_data.get ()),
333+ flat_tensor,
334+ loader,
335+ std::move (key_to_map_index));
266336}
267337
268338} // namespace extension
0 commit comments