From 19c16cb8adf8715df27366fa63b78c567c589c12 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 27 Jun 2025 17:03:50 -0700 Subject: [PATCH] Add MergedDataMap to method Differential Revision: [D77472917](https://our.internmc.facebook.com/intern/diff/D77472917/) [ghstack-poisoned] --- runtime/executor/method.cpp | 66 +++++++++++++++++++------------ runtime/executor/method.h | 6 +++ runtime/executor/targets.bzl | 1 + runtime/executor/test/targets.bzl | 1 + 4 files changed, 48 insertions(+), 26 deletions(-) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index fe44f49e7e8..679c2194270 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -328,9 +329,9 @@ Result Method::get_num_external_constants() { return n_external_constants; } -Error Method::parse_external_constants(const NamedDataMap* named_data_map) { +Error Method::parse_external_constants(const NamedDataMap* external_data_map) { ET_CHECK_OR_RETURN_ERROR( - named_data_map != nullptr, InvalidState, "named_data_map is null"); + external_data_map != nullptr, InvalidState, "external_data_map is null"); auto flatbuffer_values = serialization_plan_->values(); size_t n_value = flatbuffer_values->size(); @@ -372,7 +373,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) { continue; } Result tensor_layout = - named_data_map->get_tensor_layout(key); + external_data_map->get_tensor_layout(key); if (!tensor_layout.ok()) { ET_LOG(Info, "Failed to get metadata for key %s", key); return tensor_layout.error(); @@ -387,7 +388,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) { external_constants_[n_external_constants_].key = key; // Save the buffer. - Result buffer = named_data_map->get_data(key); + Result buffer = external_data_map->get_data(key); ET_CHECK_OR_RETURN_ERROR( buffer.ok(), InvalidExternalData, @@ -400,7 +401,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) { return Error::Ok; } -Error Method::parse_values(const NamedDataMap* named_data_map) { +Error Method::parse_values(const NamedDataMap* external_data_map) { auto flatbuffer_values = serialization_plan_->values(); ET_CHECK_OR_RETURN_ERROR( flatbuffer_values != nullptr, InvalidProgram, "Missing values"); @@ -428,7 +429,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) { if (external_constants_ == nullptr) { return Error::MemoryAllocationFailed; } - Error err = parse_external_constants(named_data_map); + Error err = parse_external_constants(external_data_map); if (err != Error::Ok) { return err; } @@ -541,7 +542,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) { program_, memory_manager_, static_cast(val), - named_data_map, + external_data_map, Span(external_constants_, n_external_constants_)); if (!t.ok()) { ET_LOG( @@ -741,7 +742,7 @@ Result Method::load( const Program* program, MemoryManager* memory_manager, EventTracer* event_tracer, - const NamedDataMap* named_data_map) { + const NamedDataMap* external_data_map) { MemoryAllocator* temp_allocator = memory_manager->temp_allocator(); if (temp_allocator == nullptr) { PlatformMemoryAllocator* platform_allocator = @@ -755,7 +756,7 @@ Result Method::load( } Method method(program, memory_manager, event_tracer, temp_allocator); ET_LOG(Debug, "Loading method: %s.", s_plan->name()->c_str()); - Error err = method.init(s_plan, named_data_map); + Error err = method.init(s_plan, external_data_map); if (err != Error::Ok) { return err; } else { @@ -766,7 +767,7 @@ Result Method::load( Error Method::init( executorch_flatbuffer::ExecutionPlan* s_plan, - const NamedDataMap* named_data_map) { + const NamedDataMap* external_data_map) { EXECUTORCH_SCOPE_PROF("Method::init"); internal::EventTracerProfileMethodScope event_tracer_profile_scope = internal::EventTracerProfileMethodScope(event_tracer_, "Method::init"); @@ -783,7 +784,7 @@ Error Method::init( { // Parse the elements of the values_ array. - Error err = parse_values(named_data_map); + Error err = parse_values(external_data_map); if (err != Error::Ok) { return err; } @@ -800,23 +801,34 @@ Error Method::init( return Error::MemoryAllocationFailed; } - // Get NamedDataMap, if it exists. - const NamedDataMap* pte_data_map = nullptr; - Result pte_data_map_res = - program_->get_named_data_map(); - if (pte_data_map_res.ok()) { - pte_data_map = pte_data_map_res.get(); - } - + // Merge NamedDataMaps. + auto pte_data_map = program_->get_named_data_map(); ET_CHECK_OR_RETURN_ERROR( - !(pte_data_map && named_data_map), - NotSupported, - "NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues"); - - if (!named_data_map || named_data_map->get_num_keys().get() == 0) { - named_data_map = pte_data_map; + pte_data_map.ok() || pte_data_map.error() == Error::NotFound, + InvalidProgram, + "Failed to get named data map from program: 0x%" PRIx32, + static_cast(pte_data_map.error())); + + const NamedDataMap* named_data_map = nullptr; + // Merge them. + if (external_data_map && pte_data_map.ok()) { + const std::array data_maps = { + external_data_map, pte_data_map.ok() ? pte_data_map.get() : nullptr}; + + auto merged = MergedDataMap<2>::load(data_maps); + if (!merged.ok()) { + return merged.error(); + } + merged_data_map_ = method_allocator->allocateInstance>(); + if (merged_data_map_ == nullptr) { + return Error::MemoryAllocationFailed; + } + new (merged_data_map_) MergedDataMap<2>(std::move(merged.get())); + } else if (external_data_map) { + named_data_map = external_data_map; + } else if (pte_data_map.ok()) { + named_data_map = pte_data_map.get(); } - // n_delegate_ counts the number of successfully-initialized delegates for // ~Method() to clean up, and is incremented at the bottom of the loop. This // makes it safe for errors to return without updating any state. @@ -1680,6 +1692,8 @@ Method::~Method() { for (const auto i : c10::irange(n_external_constants_)) { external_constants_[i].buffer.~FreeableBuffer(); } + // Free the MergedDataMap + merged_data_map_->~MergedDataMap<2>(); // All other fields are trivially destructible. } } // namespace ET_RUNTIME_NAMESPACE diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 99a6aea439f..841b91a79a2 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -76,6 +77,7 @@ class Method final { delegates_(rhs.delegates_), n_chains_(rhs.n_chains_), chains_(rhs.chains_), + merged_data_map_(std::move(rhs.merged_data_map_)), external_constants_(rhs.external_constants_), n_external_constants_(rhs.n_external_constants_), init_state_(rhs.init_state_) { @@ -85,6 +87,8 @@ class Method final { rhs.values_ = nullptr; rhs.n_delegate_ = 0; rhs.delegates_ = nullptr; + + rhs.merged_data_map_ = nullptr; rhs.n_external_constants_ = 0; rhs.external_constants_ = nullptr; @@ -314,6 +318,7 @@ class Method final { delegates_(nullptr), n_chains_(0), chains_(nullptr), + merged_data_map_(nullptr), external_constants_(nullptr), n_external_constants_(0), init_state_(InitializationState::Uninitialized) {} @@ -364,6 +369,7 @@ class Method final { size_t n_chains_; Chain* chains_; + MergedDataMap<2>* merged_data_map_; NamedData* external_constants_; size_t n_external_constants_ = 0; diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl index 98165373b73..424cc3e147b 100644 --- a/runtime/executor/targets.bzl +++ b/runtime/executor/targets.bzl @@ -117,6 +117,7 @@ def define_common_targets(): exported_deps = [ ":memory_manager", ":pte_data_map" + aten_suffix, + ":merged_data_map" + aten_suffix, "//executorch/runtime/backend:interface" + aten_suffix, "//executorch/runtime/core:core", "//executorch/runtime/core:named_data_map" + aten_suffix, diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl index 7b4672e4414..1174b01f42b 100644 --- a/runtime/executor/test/targets.bzl +++ b/runtime/executor/test/targets.bzl @@ -163,6 +163,7 @@ def define_common_targets(is_fbcode = False): ], deps = [ ":managed_memory_manager", + "//executorch/runtime/executor:merged_data_map", "//executorch/runtime/executor:program", "//executorch/extension/data_loader:file_data_loader", "//executorch/extension/flat_tensor:flat_tensor_data_map",