2020#include < executorch/runtime/core/named_data_map.h>
2121#include < executorch/runtime/core/span.h>
2222#include < executorch/runtime/executor/memory_manager.h>
23+ #include < executorch/runtime/executor/merged_data_map.h>
2324#include < executorch/runtime/executor/platform_memory_allocator.h>
2425#include < executorch/runtime/executor/program.h>
2526#include < executorch/runtime/executor/tensor_parser.h>
@@ -328,9 +329,9 @@ Result<size_t> Method::get_num_external_constants() {
328329 return n_external_constants;
329330}
330331
331- Error Method::parse_external_constants (const NamedDataMap* named_data_map ) {
332+ Error Method::parse_external_constants (const NamedDataMap* external_data_map ) {
332333 ET_CHECK_OR_RETURN_ERROR (
333- named_data_map != nullptr , InvalidState, " named_data_map is null" );
334+ external_data_map != nullptr , InvalidState, " external_data_map is null" );
334335 auto flatbuffer_values = serialization_plan_->values ();
335336 size_t n_value = flatbuffer_values->size ();
336337
@@ -372,7 +373,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
372373 continue ;
373374 }
374375 Result<const TensorLayout> tensor_layout =
375- named_data_map ->get_tensor_layout (key);
376+ external_data_map ->get_tensor_layout (key);
376377 if (!tensor_layout.ok ()) {
377378 ET_LOG (Info, " Failed to get metadata for key %s" , key);
378379 return tensor_layout.error ();
@@ -387,7 +388,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
387388 external_constants_[n_external_constants_].key = key;
388389
389390 // Save the buffer.
390- Result<FreeableBuffer> buffer = named_data_map ->get_data (key);
391+ Result<FreeableBuffer> buffer = external_data_map ->get_data (key);
391392 ET_CHECK_OR_RETURN_ERROR (
392393 buffer.ok (),
393394 InvalidExternalData,
@@ -400,7 +401,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
400401 return Error::Ok;
401402}
402403
403- Error Method::parse_values (const NamedDataMap* named_data_map ) {
404+ Error Method::parse_values (const NamedDataMap* external_data_map ) {
404405 auto flatbuffer_values = serialization_plan_->values ();
405406 ET_CHECK_OR_RETURN_ERROR (
406407 flatbuffer_values != nullptr , InvalidProgram, " Missing values" );
@@ -428,7 +429,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
428429 if (external_constants_ == nullptr ) {
429430 return Error::MemoryAllocationFailed;
430431 }
431- Error err = parse_external_constants (named_data_map );
432+ Error err = parse_external_constants (external_data_map );
432433 if (err != Error::Ok) {
433434 return err;
434435 }
@@ -541,7 +542,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
541542 program_,
542543 memory_manager_,
543544 static_cast <const executorch_flatbuffer::Tensor*>(val),
544- named_data_map ,
545+ external_data_map ,
545546 Span<NamedData>(external_constants_, n_external_constants_));
546547 if (!t.ok ()) {
547548 ET_LOG (
@@ -741,7 +742,7 @@ Result<Method> Method::load(
741742 const Program* program,
742743 MemoryManager* memory_manager,
743744 EventTracer* event_tracer,
744- const NamedDataMap* named_data_map ) {
745+ const NamedDataMap* external_data_map ) {
745746 MemoryAllocator* temp_allocator = memory_manager->temp_allocator ();
746747 if (temp_allocator == nullptr ) {
747748 PlatformMemoryAllocator* platform_allocator =
@@ -755,7 +756,7 @@ Result<Method> Method::load(
755756 }
756757 Method method (program, memory_manager, event_tracer, temp_allocator);
757758 ET_LOG (Debug, " Loading method: %s." , s_plan->name ()->c_str ());
758- Error err = method.init (s_plan, named_data_map );
759+ Error err = method.init (s_plan, external_data_map );
759760 if (err != Error::Ok) {
760761 return err;
761762 } else {
@@ -766,7 +767,7 @@ Result<Method> Method::load(
766767
767768Error Method::init (
768769 executorch_flatbuffer::ExecutionPlan* s_plan,
769- const NamedDataMap* named_data_map ) {
770+ const NamedDataMap* external_data_map ) {
770771 EXECUTORCH_SCOPE_PROF (" Method::init" );
771772 internal::EventTracerProfileMethodScope event_tracer_profile_scope =
772773 internal::EventTracerProfileMethodScope (event_tracer_, " Method::init" );
@@ -783,7 +784,7 @@ Error Method::init(
783784
784785 {
785786 // Parse the elements of the values_ array.
786- Error err = parse_values (named_data_map );
787+ Error err = parse_values (external_data_map );
787788 if (err != Error::Ok) {
788789 return err;
789790 }
@@ -800,23 +801,34 @@ Error Method::init(
800801 return Error::MemoryAllocationFailed;
801802 }
802803
803- // Get NamedDataMap, if it exists.
804- const NamedDataMap* pte_data_map = nullptr ;
805- Result<const NamedDataMap*> pte_data_map_res =
806- program_->get_named_data_map ();
807- if (pte_data_map_res.ok ()) {
808- pte_data_map = pte_data_map_res.get ();
809- }
810-
804+ // Merge NamedDataMaps.
805+ auto pte_data_map = program_->get_named_data_map ();
811806 ET_CHECK_OR_RETURN_ERROR (
812- !(pte_data_map && named_data_map),
813- NotSupported,
814- " NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues" );
815-
816- if (!named_data_map || named_data_map->get_num_keys ().get () == 0 ) {
817- named_data_map = pte_data_map;
807+ pte_data_map.ok () || pte_data_map.error () == Error::NotFound,
808+ InvalidProgram,
809+ " Failed to get named data map from program: 0x%" PRIx32,
810+ static_cast <uint32_t >(pte_data_map.error ()));
811+
812+ const NamedDataMap* named_data_map = nullptr ;
813+ // Merge them.
814+ if (external_data_map && pte_data_map.ok ()) {
815+ const std::array<const NamedDataMap*, 2 > data_maps = {
816+ external_data_map, pte_data_map.ok () ? pte_data_map.get () : nullptr };
817+
818+ auto merged = MergedDataMap<2 >::load (data_maps);
819+ if (!merged.ok ()) {
820+ return merged.error ();
821+ }
822+ merged_data_map_ = method_allocator->allocateInstance <MergedDataMap<2 >>();
823+ if (merged_data_map_ == nullptr ) {
824+ return Error::MemoryAllocationFailed;
825+ }
826+ new (merged_data_map_) MergedDataMap<2 >(std::move (merged.get ()));
827+ } else if (external_data_map) {
828+ named_data_map = external_data_map;
829+ } else if (pte_data_map.ok ()) {
830+ named_data_map = pte_data_map.get ();
818831 }
819-
820832 // n_delegate_ counts the number of successfully-initialized delegates for
821833 // ~Method() to clean up, and is incremented at the bottom of the loop. This
822834 // makes it safe for errors to return without updating any state.
@@ -1680,6 +1692,8 @@ Method::~Method() {
16801692 for (const auto i : c10::irange (n_external_constants_)) {
16811693 external_constants_[i].buffer .~FreeableBuffer ();
16821694 }
1695+ // Free the MergedDataMap
1696+ merged_data_map_->~MergedDataMap<2 >();
16831697 // All other fields are trivially destructible.
16841698}
16851699} // namespace ET_RUNTIME_NAMESPACE
0 commit comments