Skip to content

Commit 19c16cb

Browse files
committed
Add MergedDataMap to method
Differential Revision: [D77472917](https://our.internmc.facebook.com/intern/diff/D77472917/) [ghstack-poisoned]
1 parent 59bb3df commit 19c16cb

File tree

4 files changed

+48
-26
lines changed

4 files changed

+48
-26
lines changed

runtime/executor/method.cpp

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <executorch/runtime/core/named_data_map.h>
2121
#include <executorch/runtime/core/span.h>
2222
#include <executorch/runtime/executor/memory_manager.h>
23+
#include <executorch/runtime/executor/merged_data_map.h>
2324
#include <executorch/runtime/executor/platform_memory_allocator.h>
2425
#include <executorch/runtime/executor/program.h>
2526
#include <executorch/runtime/executor/tensor_parser.h>
@@ -328,9 +329,9 @@ Result<size_t> Method::get_num_external_constants() {
328329
return n_external_constants;
329330
}
330331

331-
Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
332+
Error Method::parse_external_constants(const NamedDataMap* external_data_map) {
332333
ET_CHECK_OR_RETURN_ERROR(
333-
named_data_map != nullptr, InvalidState, "named_data_map is null");
334+
external_data_map != nullptr, InvalidState, "external_data_map is null");
334335
auto flatbuffer_values = serialization_plan_->values();
335336
size_t n_value = flatbuffer_values->size();
336337

@@ -372,7 +373,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
372373
continue;
373374
}
374375
Result<const TensorLayout> tensor_layout =
375-
named_data_map->get_tensor_layout(key);
376+
external_data_map->get_tensor_layout(key);
376377
if (!tensor_layout.ok()) {
377378
ET_LOG(Info, "Failed to get metadata for key %s", key);
378379
return tensor_layout.error();
@@ -387,7 +388,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
387388
external_constants_[n_external_constants_].key = key;
388389

389390
// Save the buffer.
390-
Result<FreeableBuffer> buffer = named_data_map->get_data(key);
391+
Result<FreeableBuffer> buffer = external_data_map->get_data(key);
391392
ET_CHECK_OR_RETURN_ERROR(
392393
buffer.ok(),
393394
InvalidExternalData,
@@ -400,7 +401,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
400401
return Error::Ok;
401402
}
402403

403-
Error Method::parse_values(const NamedDataMap* named_data_map) {
404+
Error Method::parse_values(const NamedDataMap* external_data_map) {
404405
auto flatbuffer_values = serialization_plan_->values();
405406
ET_CHECK_OR_RETURN_ERROR(
406407
flatbuffer_values != nullptr, InvalidProgram, "Missing values");
@@ -428,7 +429,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
428429
if (external_constants_ == nullptr) {
429430
return Error::MemoryAllocationFailed;
430431
}
431-
Error err = parse_external_constants(named_data_map);
432+
Error err = parse_external_constants(external_data_map);
432433
if (err != Error::Ok) {
433434
return err;
434435
}
@@ -541,7 +542,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
541542
program_,
542543
memory_manager_,
543544
static_cast<const executorch_flatbuffer::Tensor*>(val),
544-
named_data_map,
545+
external_data_map,
545546
Span<NamedData>(external_constants_, n_external_constants_));
546547
if (!t.ok()) {
547548
ET_LOG(
@@ -741,7 +742,7 @@ Result<Method> Method::load(
741742
const Program* program,
742743
MemoryManager* memory_manager,
743744
EventTracer* event_tracer,
744-
const NamedDataMap* named_data_map) {
745+
const NamedDataMap* external_data_map) {
745746
MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
746747
if (temp_allocator == nullptr) {
747748
PlatformMemoryAllocator* platform_allocator =
@@ -755,7 +756,7 @@ Result<Method> Method::load(
755756
}
756757
Method method(program, memory_manager, event_tracer, temp_allocator);
757758
ET_LOG(Debug, "Loading method: %s.", s_plan->name()->c_str());
758-
Error err = method.init(s_plan, named_data_map);
759+
Error err = method.init(s_plan, external_data_map);
759760
if (err != Error::Ok) {
760761
return err;
761762
} else {
@@ -766,7 +767,7 @@ Result<Method> Method::load(
766767

767768
Error Method::init(
768769
executorch_flatbuffer::ExecutionPlan* s_plan,
769-
const NamedDataMap* named_data_map) {
770+
const NamedDataMap* external_data_map) {
770771
EXECUTORCH_SCOPE_PROF("Method::init");
771772
internal::EventTracerProfileMethodScope event_tracer_profile_scope =
772773
internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
@@ -783,7 +784,7 @@ Error Method::init(
783784

784785
{
785786
// Parse the elements of the values_ array.
786-
Error err = parse_values(named_data_map);
787+
Error err = parse_values(external_data_map);
787788
if (err != Error::Ok) {
788789
return err;
789790
}
@@ -800,23 +801,34 @@ Error Method::init(
800801
return Error::MemoryAllocationFailed;
801802
}
802803

803-
// Get NamedDataMap, if it exists.
804-
const NamedDataMap* pte_data_map = nullptr;
805-
Result<const NamedDataMap*> pte_data_map_res =
806-
program_->get_named_data_map();
807-
if (pte_data_map_res.ok()) {
808-
pte_data_map = pte_data_map_res.get();
809-
}
810-
804+
// Merge NamedDataMaps.
805+
auto pte_data_map = program_->get_named_data_map();
811806
ET_CHECK_OR_RETURN_ERROR(
812-
!(pte_data_map && named_data_map),
813-
NotSupported,
814-
"NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues");
815-
816-
if (!named_data_map || named_data_map->get_num_keys().get() == 0) {
817-
named_data_map = pte_data_map;
807+
pte_data_map.ok() || pte_data_map.error() == Error::NotFound,
808+
InvalidProgram,
809+
"Failed to get named data map from program: 0x%" PRIx32,
810+
static_cast<uint32_t>(pte_data_map.error()));
811+
812+
const NamedDataMap* named_data_map = nullptr;
813+
// Merge them.
814+
if (external_data_map && pte_data_map.ok()) {
815+
const std::array<const NamedDataMap*, 2> data_maps = {
816+
external_data_map, pte_data_map.ok() ? pte_data_map.get() : nullptr};
817+
818+
auto merged = MergedDataMap<2>::load(data_maps);
819+
if (!merged.ok()) {
820+
return merged.error();
821+
}
822+
merged_data_map_ = method_allocator->allocateInstance<MergedDataMap<2>>();
823+
if (merged_data_map_ == nullptr) {
824+
return Error::MemoryAllocationFailed;
825+
}
826+
new (merged_data_map_) MergedDataMap<2>(std::move(merged.get()));
827+
} else if (external_data_map) {
828+
named_data_map = external_data_map;
829+
} else if (pte_data_map.ok()) {
830+
named_data_map = pte_data_map.get();
818831
}
819-
820832
// n_delegate_ counts the number of successfully-initialized delegates for
821833
// ~Method() to clean up, and is incremented at the bottom of the loop. This
822834
// makes it safe for errors to return without updating any state.
@@ -1680,6 +1692,8 @@ Method::~Method() {
16801692
for (const auto i : c10::irange(n_external_constants_)) {
16811693
external_constants_[i].buffer.~FreeableBuffer();
16821694
}
1695+
// Free the MergedDataMap
1696+
merged_data_map_->~MergedDataMap<2>();
16831697
// All other fields are trivially destructible.
16841698
}
16851699
} // namespace ET_RUNTIME_NAMESPACE

runtime/executor/method.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <executorch/runtime/core/named_data_map.h>
2121
#include <executorch/runtime/core/span.h>
2222
#include <executorch/runtime/executor/memory_manager.h>
23+
#include <executorch/runtime/executor/merged_data_map.h>
2324
#include <executorch/runtime/executor/method_meta.h>
2425
#include <executorch/runtime/platform/compiler.h>
2526

@@ -76,6 +77,7 @@ class Method final {
7677
delegates_(rhs.delegates_),
7778
n_chains_(rhs.n_chains_),
7879
chains_(rhs.chains_),
80+
merged_data_map_(std::move(rhs.merged_data_map_)),
7981
external_constants_(rhs.external_constants_),
8082
n_external_constants_(rhs.n_external_constants_),
8183
init_state_(rhs.init_state_) {
@@ -85,6 +87,8 @@ class Method final {
8587
rhs.values_ = nullptr;
8688
rhs.n_delegate_ = 0;
8789
rhs.delegates_ = nullptr;
90+
91+
rhs.merged_data_map_ = nullptr;
8892
rhs.n_external_constants_ = 0;
8993
rhs.external_constants_ = nullptr;
9094

@@ -314,6 +318,7 @@ class Method final {
314318
delegates_(nullptr),
315319
n_chains_(0),
316320
chains_(nullptr),
321+
merged_data_map_(nullptr),
317322
external_constants_(nullptr),
318323
n_external_constants_(0),
319324
init_state_(InitializationState::Uninitialized) {}
@@ -364,6 +369,7 @@ class Method final {
364369
size_t n_chains_;
365370
Chain* chains_;
366371

372+
MergedDataMap<2>* merged_data_map_;
367373
NamedData* external_constants_;
368374
size_t n_external_constants_ = 0;
369375

runtime/executor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def define_common_targets():
117117
exported_deps = [
118118
":memory_manager",
119119
":pte_data_map" + aten_suffix,
120+
":merged_data_map" + aten_suffix,
120121
"//executorch/runtime/backend:interface" + aten_suffix,
121122
"//executorch/runtime/core:core",
122123
"//executorch/runtime/core:named_data_map" + aten_suffix,

runtime/executor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def define_common_targets(is_fbcode = False):
163163
],
164164
deps = [
165165
":managed_memory_manager",
166+
"//executorch/runtime/executor:merged_data_map",
166167
"//executorch/runtime/executor:program",
167168
"//executorch/extension/data_loader:file_data_loader",
168169
"//executorch/extension/flat_tensor:flat_tensor_data_map",

0 commit comments

Comments
 (0)