diff --git a/exir/backend/test/demos/rpc/ExecutorBackend.cpp b/exir/backend/test/demos/rpc/ExecutorBackend.cpp index 7632e4ad33c..570c6ff7982 100644 --- a/exir/backend/test/demos/rpc/ExecutorBackend.cpp +++ b/exir/backend/test/demos/rpc/ExecutorBackend.cpp @@ -158,7 +158,7 @@ class ExecutorBackend final : public ::executorch::runtime::BackendInterface { new (client_memory_manager) MemoryManager(client_method_allocator, client_planned_memory); - const NamedDataMap* named_data_map = context.get_named_data_map(); + NamedDataMap* named_data_map = context.get_named_data_map(); // Construct the client Method Result method_res = client_program->load_method( "forward", diff --git a/runtime/backend/backend_init_context.h b/runtime/backend/backend_init_context.h index 5a4b70e0dbc..3ce42118961 100644 --- a/runtime/backend/backend_init_context.h +++ b/runtime/backend/backend_init_context.h @@ -23,7 +23,7 @@ class BackendInitContext final { MemoryAllocator* runtime_allocator, EventTracer* event_tracer = nullptr, const char* method_name = nullptr, - const NamedDataMap* named_data_map = nullptr) + NamedDataMap* named_data_map = nullptr) : runtime_allocator_(runtime_allocator), #ifdef ET_EVENT_TRACER_ENABLED event_tracer_(event_tracer), @@ -65,7 +65,7 @@ class BackendInitContext final { /** Get the named data map from ExecuTorch runtime. * This provides a way for backends to retrieve data blobs by key. */ - const NamedDataMap* get_named_data_map() const { + NamedDataMap* get_named_data_map() const { return named_data_map_; } @@ -73,7 +73,7 @@ class BackendInitContext final { MemoryAllocator* runtime_allocator_ = nullptr; EventTracer* event_tracer_ = nullptr; const char* method_name_ = nullptr; - const NamedDataMap* named_data_map_ = nullptr; + NamedDataMap* named_data_map_ = nullptr; }; } // namespace ET_RUNTIME_NAMESPACE diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index fe44f49e7e8..adcc4dc6e78 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -741,7 +741,7 @@ Result Method::load( const Program* program, MemoryManager* memory_manager, EventTracer* event_tracer, - const NamedDataMap* named_data_map) { + NamedDataMap* named_data_map) { MemoryAllocator* temp_allocator = memory_manager->temp_allocator(); if (temp_allocator == nullptr) { PlatformMemoryAllocator* platform_allocator = @@ -766,7 +766,7 @@ Result Method::load( Error Method::init( executorch_flatbuffer::ExecutionPlan* s_plan, - const NamedDataMap* named_data_map) { + NamedDataMap* named_data_map) { EXECUTORCH_SCOPE_PROF("Method::init"); internal::EventTracerProfileMethodScope event_tracer_profile_scope = internal::EventTracerProfileMethodScope(event_tracer_, "Method::init"); @@ -800,21 +800,23 @@ Error Method::init( return Error::MemoryAllocationFailed; } - // Get NamedDataMap, if it exists. - const NamedDataMap* pte_data_map = nullptr; - Result pte_data_map_res = - program_->get_named_data_map(); - if (pte_data_map_res.ok()) { - pte_data_map = pte_data_map_res.get(); - } - - ET_CHECK_OR_RETURN_ERROR( - !(pte_data_map && named_data_map), - NotSupported, - "NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues"); - - if (!named_data_map || named_data_map->get_num_keys().get() == 0) { - named_data_map = pte_data_map; + // Resolve NamedDataMaps. + auto pte_data_map = program_->get_named_data_map(); + if (pte_data_map.ok()) { + if (named_data_map != nullptr) { + Error error = named_data_map->merge(pte_data_map.get()); + ET_CHECK_OR_RETURN_ERROR( + error == Error::Ok, + InvalidExternalData, + "Failed to merge named_data_map with pte_data_map."); + } else { + named_data_map = const_cast(pte_data_map.get()); + } + } else if (pte_data_map.error() != Error::NotFound) { + // Error::NotFound is expected if the program does not have shared data. + // In this case, expect pte_data_map to be empty/null, and we can proceed + // with the named data map only. + return pte_data_map.error(); } // n_delegate_ counts the number of successfully-initialized delegates for diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 99a6aea439f..d406d7acc15 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -324,7 +324,7 @@ class Method final { const Program* program, MemoryManager* memory_manager, EventTracer* event_tracer, - const NamedDataMap* named_data_map); + NamedDataMap* named_data_map); /** * Initialize the method from its serialized representation. @@ -333,7 +333,7 @@ class Method final { */ ET_NODISCARD Error init( executorch_flatbuffer::ExecutionPlan* s_plan, - const NamedDataMap* named_data_map); + NamedDataMap* named_data_map); /// Returns true if the Method was successfully initialized. inline bool initialized() const { diff --git a/runtime/executor/program.cpp b/runtime/executor/program.cpp index 238c806b1d6..062842609ea 100644 --- a/runtime/executor/program.cpp +++ b/runtime/executor/program.cpp @@ -258,7 +258,7 @@ Result Program::load_method( const char* method_name, MemoryManager* memory_manager, EventTracer* event_tracer, - const NamedDataMap* named_data_map) const { + NamedDataMap* named_data_map) const { EXECUTORCH_SCOPE_PROF("Program::load_method"); internal::event_tracer_create_event_block(event_tracer, "Default"); internal::EventTracerProfileMethodScope event_tracer_scope = @@ -372,9 +372,9 @@ Result Program::get_constant_buffer_data( } } -Result Program::get_named_data_map() const { +Result Program::get_named_data_map() const { if (pte_data_map_.has_value()) { - return &pte_data_map_.value(); + return const_cast(&pte_data_map_.value()); } return Error::NotFound; } diff --git a/runtime/executor/program.h b/runtime/executor/program.h index 9670fd7c79f..0b63b6a8c79 100644 --- a/runtime/executor/program.h +++ b/runtime/executor/program.h @@ -113,7 +113,7 @@ class Program final { * Get the named data map from the program. * @return The named data map. */ - Result get_named_data_map() const; + Result get_named_data_map() const; /** * Returns the number of methods in the program. @@ -148,7 +148,7 @@ class Program final { const char* method_name, MemoryManager* memory_manager, EventTracer* event_tracer = nullptr, - const NamedDataMap* named_data_map = nullptr) const; + NamedDataMap* named_data_map = nullptr) const; /** * Gathers metadata for the named method. diff --git a/runtime/executor/test/program_test.cpp b/runtime/executor/test/program_test.cpp index 962bf8f548a..1a4d819cbed 100644 --- a/runtime/executor/test/program_test.cpp +++ b/runtime/executor/test/program_test.cpp @@ -378,7 +378,7 @@ TEST_F(ProgramTest, GetNamedDataMap_Fail) { // Get the named data map. Expect to fail, as add.pte does not have any // named data segments. - Result named_data_map = + Result named_data_map = program->get_named_data_map(); EXPECT_EQ(named_data_map.error(), Error::NotFound); }