diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
index 5a94b47b954..37e1cd2edac 100644
--- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
+++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
@@ -33,8 +33,8 @@ class FlatTensorDataMapTest : public ::testing::Test {
     // first.
     executorch::runtime::runtime_init();
 
-    // Load data map. The eager linear model is defined at:
-    // //executorch/test/models/linear_model.py
+    // Load data map. The eager addmul model is defined at:
+    // //executorch/test/models/export_program.py
     const char* path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
     Result<FileDataLoader> loader = FileDataLoader::from(path);
     ASSERT_EQ(loader.error(), Error::Ok);
diff --git a/runtime/executor/merged_data_map.h b/runtime/executor/merged_data_map.h
new file mode 100644
index 00000000000..0f0175098ae
--- /dev/null
+++ b/runtime/executor/merged_data_map.h
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/named_data_map.h>
+
+namespace executorch {
+namespace ET_RUNTIME_NAMESPACE {
+namespace internal {
+
+/**
+ * A NamedDataMap implementation that wraps other NamedDataMaps.
+ */
+class MergedDataMap final : public NamedDataMap {
+ public:
+  /**
+   * Creates a new NamedDataMap that wraps two other data maps.
+   *
+   * @param[in] first The first NamedDataMap to merge.
+   * @param[in] second The second NamedDataMap to merge.
+   * Note: the data maps must outlive the MergedDataMap instance.
+   */
+  static Result<MergedDataMap> load(
+      const NamedDataMap* first,
+      const NamedDataMap* second) {
+    ET_CHECK_OR_RETURN_ERROR(
+        first != nullptr && second != nullptr,
+        InvalidArgument,
+        "Input data map is null.");
+
+    // Check for duplicate keys.
+    for (uint32_t k = 0; k < first->get_num_keys().get(); k++) {
+      const auto key = first->get_key(k).get();
+      ET_CHECK_OR_RETURN_ERROR(
+          second->get_tensor_layout(key).error() == Error::NotFound,
+          InvalidArgument,
+          "Duplicate key %s.",
+          key);
+    }
+    return MergedDataMap(first, second);
+  }
+
+  /**
+   * Retrieve the tensor_layout for the specified key.
+   *
+   * @param[in] key The name of the tensor to get metadata on.
+   *
+   * @return Error::NotFound if the key is not present.
+   */
+  ET_NODISCARD
+  Result<const TensorLayout> get_tensor_layout(
+      executorch::aten::string_view key) const override {
+    auto layout = first_->get_tensor_layout(key);
+    if (layout.ok()) {
+      return layout.get();
+    }
+    if (layout.error() != Error::NotFound) {
+      return layout.error();
+    }
+    return second_->get_tensor_layout(key);
+  }
+
+  /**
+   * Retrieve read-only data for the specified key.
+   *
+   * @param[in] key The name of the tensor to get data on.
+   *
+   * @return error if the key is not present or data cannot be loaded.
+   */
+  ET_NODISCARD
+  Result<FreeableBuffer> get_data(
+      executorch::aten::string_view key) const override {
+    auto data = first_->get_data(key);
+    if (data.error() != Error::NotFound) {
+      return data;
+    }
+    return second_->get_data(key);
+  }
+
+  /**
+   * Loads the data of the specified tensor into the provided buffer.
+   * Not used in the MergedDataMap.
+   *
+   * @param[in] key The name of the tensor to get the data of.
+   * @param[in] buffer The buffer to load data into. Must point to at least
+   * `size` bytes of memory.
+   * @param[in] size The number of bytes to load.
+   *
+   * @returns an Error indicating if the load was successful.
+   */
+  ET_NODISCARD Error load_data_into(
+      ET_UNUSED executorch::aten::string_view key,
+      ET_UNUSED void* buffer,
+      ET_UNUSED size_t size) const override {
+    return Error::NotImplemented;
+  }
+
+  /**
+   * @returns The number of keys in the map.
+   */
+  ET_NODISCARD Result<uint32_t> get_num_keys() const override {
+    return first_->get_num_keys().get() + second_->get_num_keys().get();
+  }
+
+  /**
+   * @returns The key at the specified index, error if index out of bounds.
+   */
+  ET_NODISCARD Result<const char*> get_key(uint32_t index) const override {
+    uint32_t total_num_keys = get_num_keys().get();
+    ET_CHECK_OR_RETURN_ERROR(
+        index >= 0 && index < total_num_keys,
+        InvalidArgument,
+        "Index %u out of range of size %u",
+        index,
+        total_num_keys);
+
+    if (index < first_->get_num_keys().get()) {
+      return first_->get_key(index);
+    } else {
+      return second_->get_key(index - first_->get_num_keys().get());
+    }
+  }
+
+  MergedDataMap(MergedDataMap&&) noexcept = default;
+
+  ~MergedDataMap() override = default;
+
+ private:
+  MergedDataMap(const NamedDataMap* first, const NamedDataMap* second)
+      : first_{first}, second_{second} {}
+
+  // Not copyable or assignable.
+  MergedDataMap(const MergedDataMap& rhs) = delete;
+  MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
+  MergedDataMap& operator=(const MergedDataMap& rhs) = delete;
+
+  const NamedDataMap* first_;
+  const NamedDataMap* second_;
+};
+
+} // namespace internal
+} // namespace ET_RUNTIME_NAMESPACE
+} // namespace executorch
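[Reviewer note, not part of the patch] Lookup precedence in MergedDataMap is: consult the first map, and fall back to the second only on Error::NotFound, with duplicate keys rejected at load() time, so ordering can never silently change which tensor wins. A minimal usage sketch under those semantics; the helper name and the "linear.weight" key are hypothetical, not from this patch:

// Sketch only: merge an external .ptd map with a PTE-embedded map and look up
// one entry. Both input maps must outlive the merged view.
#include <executorch/runtime/executor/merged_data_map.h>

using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::internal::MergedDataMap;

Result<FreeableBuffer> lookup_weight(
    const NamedDataMap* external_map,
    const NamedDataMap* pte_map) {
  // Fails with InvalidArgument if either map is null or they share a key.
  Result<MergedDataMap> merged = MergedDataMap::load(external_map, pte_map);
  if (!merged.ok()) {
    return merged.error();
  }
  // "linear.weight" is a placeholder key; both underlying maps are searched.
  return merged->get_data("linear.weight");
}

Because the inputs must outlive the merged view, Method (below) stores its merged map in method-allocator memory rather than on the stack.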
diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp
index fe44f49e7e8..327c0cb9b6f 100644
--- a/runtime/executor/method.cpp
+++ b/runtime/executor/method.cpp
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include <executorch/runtime/executor/merged_data_map.h>
 #include
 #include
 #include
@@ -328,9 +329,9 @@ Result<size_t> Method::get_num_external_constants() {
   return n_external_constants;
 }
 
-Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
+Error Method::parse_external_constants(const NamedDataMap* external_data_map) {
   ET_CHECK_OR_RETURN_ERROR(
-      named_data_map != nullptr, InvalidState, "named_data_map is null");
+      external_data_map != nullptr, InvalidState, "external_data_map is null");
 
   auto flatbuffer_values = serialization_plan_->values();
   size_t n_value = flatbuffer_values->size();
@@ -372,7 +373,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
       continue;
     }
     Result<const TensorLayout> tensor_layout =
-        named_data_map->get_tensor_layout(key);
+        external_data_map->get_tensor_layout(key);
     if (!tensor_layout.ok()) {
       ET_LOG(Info, "Failed to get metadata for key %s", key);
       return tensor_layout.error();
@@ -387,7 +388,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
     external_constants_[n_external_constants_].key = key;
 
     // Save the buffer.
-    Result<FreeableBuffer> buffer = named_data_map->get_data(key);
+    Result<FreeableBuffer> buffer = external_data_map->get_data(key);
     ET_CHECK_OR_RETURN_ERROR(
         buffer.ok(),
         InvalidExternalData,
@@ -400,7 +401,7 @@
   return Error::Ok;
 }
 
-Error Method::parse_values(const NamedDataMap* named_data_map) {
+Error Method::parse_values(const NamedDataMap* external_data_map) {
   auto flatbuffer_values = serialization_plan_->values();
   ET_CHECK_OR_RETURN_ERROR(
       flatbuffer_values != nullptr, InvalidProgram, "Missing values");
@@ -428,7 +429,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
     if (external_constants_ == nullptr) {
      return Error::MemoryAllocationFailed;
    }
-    Error err = parse_external_constants(named_data_map);
+    Error err = parse_external_constants(external_data_map);
    if (err != Error::Ok) {
      return err;
    }
@@ -541,7 +542,7 @@ Error Method::parse_values(const NamedDataMap* named_data_map) {
            program_,
            memory_manager_,
            static_cast(val),
-            named_data_map,
+            external_data_map,
            Span(external_constants_, n_external_constants_));
        if (!t.ok()) {
          ET_LOG(
@@ -741,7 +742,7 @@ Result<Method> Method::load(
     const Program* program,
     MemoryManager* memory_manager,
     EventTracer* event_tracer,
-    const NamedDataMap* named_data_map) {
+    const NamedDataMap* external_data_map) {
   MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
   if (temp_allocator == nullptr) {
     PlatformMemoryAllocator* platform_allocator =
@@ -755,7 +756,7 @@ Result<Method> Method::load(
   }
   Method method(program, memory_manager, event_tracer, temp_allocator);
   ET_LOG(Debug, "Loading method: %s.", s_plan->name()->c_str());
-  Error err = method.init(s_plan, named_data_map);
+  Error err = method.init(s_plan, external_data_map);
   if (err != Error::Ok) {
     return err;
   } else {
@@ -766,7 +767,7 @@
 Error Method::init(
     executorch_flatbuffer::ExecutionPlan* s_plan,
-    const NamedDataMap* named_data_map) {
+    const NamedDataMap* external_data_map) {
   EXECUTORCH_SCOPE_PROF("Method::init");
   internal::EventTracerProfileMethodScope event_tracer_profile_scope =
       internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
 
@@ -783,7 +784,7 @@ Error Method::init(
   {
     // Parse the elements of the values_ array.
-    Error err = parse_values(named_data_map);
+    Error err = parse_values(external_data_map);
     if (err != Error::Ok) {
       return err;
     }
@@ -800,21 +801,34 @@ Error Method::init(
     return Error::MemoryAllocationFailed;
   }
 
-  // Get NamedDataMap, if it exists.
-  const NamedDataMap* pte_data_map = nullptr;
-  Result<const NamedDataMap*> pte_data_map_res =
-      program_->get_named_data_map();
-  if (pte_data_map_res.ok()) {
-    pte_data_map = pte_data_map_res.get();
-  }
-
+  // Get PTE data map, if it exists.
+  auto pte_data_map = program_->get_named_data_map();
   ET_CHECK_OR_RETURN_ERROR(
-      !(pte_data_map && named_data_map),
-      NotSupported,
-      "NamedDataMap merge not supported; both pte_data_map and named_data_map are non-empty. If you see this error please file an issue at https://github.com/pytorch/executorch/issues");
-
-  if (!named_data_map || named_data_map->get_num_keys().get() == 0) {
-    named_data_map = pte_data_map;
+      pte_data_map.ok() || pte_data_map.error() == Error::NotFound,
+      InvalidProgram,
+      "Failed to get named data map from program: 0x%" PRIx32,
+      static_cast<uint32_t>(pte_data_map.error()));
+
+  const NamedDataMap* named_data_map = nullptr;
+  if (external_data_map && pte_data_map.ok()) {
+    // Merge external_data_map and pte_data_map if both are present.
+    auto merged =
+        internal::MergedDataMap::load(external_data_map, pte_data_map.get());
+    if (!merged.ok()) {
+      return merged.error();
+    }
+    // Allocate memory for the merged data map.
+    merged_data_map_ =
+        method_allocator->allocateInstance<internal::MergedDataMap>();
+    if (merged_data_map_ == nullptr) {
+      return Error::MemoryAllocationFailed;
+    }
+    new (merged_data_map_) internal::MergedDataMap(std::move(merged.get()));
+    named_data_map = merged_data_map_;
+  } else if (external_data_map) {
+    named_data_map = external_data_map;
+  } else if (pte_data_map.ok()) {
+    named_data_map = pte_data_map.get();
   }
 
   // n_delegate_ counts the number of successfully-initialized delegates for
@@ -1680,6 +1694,10 @@ Method::~Method() {
   for (const auto i : c10::irange(n_external_constants_)) {
     external_constants_[i].buffer.~FreeableBuffer();
   }
+  // Free the MergedDataMap.
+  if (merged_data_map_ != nullptr) {
+    merged_data_map_->~MergedDataMap();
+  }
   // All other fields are trivially destructible.
 }
 } // namespace ET_RUNTIME_NAMESPACE
diff --git a/runtime/executor/method.h b/runtime/executor/method.h
index 99a6aea439f..3ab38134332 100644
--- a/runtime/executor/method.h
+++ b/runtime/executor/method.h
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include <executorch/runtime/executor/merged_data_map.h>
 #include
 #include
 
@@ -76,6 +77,7 @@ class Method final {
        delegates_(rhs.delegates_),
        n_chains_(rhs.n_chains_),
        chains_(rhs.chains_),
+        merged_data_map_(std::move(rhs.merged_data_map_)),
        external_constants_(rhs.external_constants_),
        n_external_constants_(rhs.n_external_constants_),
        init_state_(rhs.init_state_) {
@@ -85,6 +87,8 @@ class Method final {
     rhs.values_ = nullptr;
     rhs.n_delegate_ = 0;
     rhs.delegates_ = nullptr;
+
+    rhs.merged_data_map_ = nullptr;
     rhs.n_external_constants_ = 0;
     rhs.external_constants_ = nullptr;
 
@@ -314,6 +318,7 @@ class Method final {
        delegates_(nullptr),
        n_chains_(0),
        chains_(nullptr),
+        merged_data_map_(nullptr),
        external_constants_(nullptr),
        n_external_constants_(0),
        init_state_(InitializationState::Uninitialized) {}
@@ -364,6 +369,7 @@ class Method final {
   size_t n_chains_;
   Chain* chains_;
 
+  internal::MergedDataMap* merged_data_map_;
   NamedData* external_constants_;
   size_t n_external_constants_ = 0;
 
diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl
index 649b2c13cc1..424cc3e147b 100644
--- a/runtime/executor/targets.bzl
+++ b/runtime/executor/targets.bzl
@@ -69,6 +69,16 @@ def define_common_targets():
            exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"],
        )
 
+        runtime.cxx_library(
+            name = "merged_data_map" + aten_suffix,
+            exported_headers = [
+                "merged_data_map.h",
+            ],
+            exported_deps = [
+                "//executorch/runtime/core:named_data_map" + aten_suffix,
+            ],
+        )
+
        runtime.cxx_library(
            name = "program" + aten_suffix,
            exported_deps = [
@@ -107,6 +117,7 @@ def define_common_targets():
            exported_deps = [
                ":memory_manager",
                ":pte_data_map" + aten_suffix,
+                ":merged_data_map" + aten_suffix,
                "//executorch/runtime/backend:interface" + aten_suffix,
                "//executorch/runtime/core:core",
                "//executorch/runtime/core:named_data_map" + aten_suffix,
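[Reviewer note, not part of the patch] The caller-visible change is in Method::init above: when a method is loaded with an external NamedDataMap and the program's .pte also embeds one, init now allocates an internal::MergedDataMap from the method allocator and resolves keys from either source, instead of failing with NotSupported; the destructor change frees it. A hedged caller-side sketch follows; it assumes the Program::load_method overload that forwards a NamedDataMap pointer, and "forward" plus the helper name are placeholders:

// Sketch only: load a method while supplying externally stored weights.
#include <executorch/runtime/core/named_data_map.h>
#include <executorch/runtime/executor/program.h>

using executorch::runtime::MemoryManager;
using executorch::runtime::Method;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Program;
using executorch::runtime::Result;

// Before this change, passing external weights for a .pte that also carries a
// named data map failed with NotSupported; now keys resolve from either map.
Result<Method> load_with_external_weights(
    const Program& program,
    MemoryManager& memory_manager,
    const NamedDataMap* external_weights) {
  return program.load_method(
      "forward", &memory_manager, /*event_tracer=*/nullptr, external_weights);
}

Callers that pass only one source of named data are unaffected; those branches in Method::init still use the single map directly, without the extra allocation.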
diff --git a/runtime/executor/test/merged_data_map_test.cpp b/runtime/executor/test/merged_data_map_test.cpp
new file mode 100644
index 00000000000..c9d1d510b97
--- /dev/null
+++ b/runtime/executor/test/merged_data_map_test.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+using namespace ::testing;
+using executorch::extension::FileDataLoader;
+using executorch::extension::FlatTensorDataMap;
+using executorch::runtime::DataLoader;
+using executorch::runtime::Error;
+using executorch::runtime::FreeableBuffer;
+using executorch::runtime::NamedDataMap;
+using executorch::runtime::Result;
+using executorch::runtime::TensorLayout;
+using executorch::runtime::internal::MergedDataMap;
+
+class MergedDataMapTest : public ::testing::Test {
+ protected:
+  void load_flat_tensor_data_map(const char* path, const char* module_name) {
+    Result<FileDataLoader> loader = FileDataLoader::from(path);
+    ASSERT_EQ(loader.error(), Error::Ok);
+    loaders_.insert(
+        {module_name,
+         std::make_unique<FileDataLoader>(std::move(loader.get()))});
+
+    Result<FlatTensorDataMap> data_map =
+        FlatTensorDataMap::load(loaders_[module_name].get());
+    EXPECT_EQ(data_map.error(), Error::Ok);
+
+    data_maps_.insert(
+        {module_name,
+         std::make_unique<FlatTensorDataMap>(std::move(data_map.get()))});
+  }
+
+  void SetUp() override {
+    // Since these tests cause ET_LOG to be called, the PAL must be initialized
+    // first.
+    executorch::runtime::runtime_init();
+
+    // Load FlatTensor data maps.
+    // The eager addmul and linear models are defined at:
+    // //executorch/test/models/export_program.py
+    load_flat_tensor_data_map(
+        std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"), "addmul");
+    load_flat_tensor_data_map(
+        std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
+  }
+
+ private:
+  // Must outlive data_maps_, but tests shouldn't need to touch it.
+  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;
+
+ protected:
+  std::unordered_map<std::string, std::unique_ptr<FlatTensorDataMap>>
+      data_maps_;
+};
+
+// Check that two tensor layouts are equivalent.
+void check_tensor_layout(TensorLayout& layout1, TensorLayout& layout2) {
+  EXPECT_EQ(layout1.scalar_type(), layout2.scalar_type());
+  EXPECT_EQ(layout1.nbytes(), layout2.nbytes());
+  EXPECT_EQ(layout1.sizes().size(), layout2.sizes().size());
+  for (size_t i = 0; i < layout1.sizes().size(); i++) {
+    EXPECT_EQ(layout1.sizes()[i], layout2.sizes()[i]);
+  }
+  EXPECT_EQ(layout1.dim_order().size(), layout2.dim_order().size());
+  for (size_t i = 0; i < layout1.dim_order().size(); i++) {
+    EXPECT_EQ(layout1.dim_order()[i], layout2.dim_order()[i]);
+  }
+}
+
+// Given that ndm is part of merged, check that all the API calls on ndm
+// produce the same results as merged.
+void compare_ndm_api_calls(
+    const NamedDataMap* ndm,
+    const NamedDataMap* merged) {
+  uint32_t num_keys = ndm->get_num_keys().get();
+  for (uint32_t i = 0; i < num_keys; i++) {
+    auto key = ndm->get_key(i).get();
+
+    // Compare get_tensor_layout.
+    auto ndm_meta = ndm->get_tensor_layout(key).get();
+    auto merged_meta = merged->get_tensor_layout(key).get();
+    check_tensor_layout(ndm_meta, merged_meta);
+
+    // Compare get_data.
+    auto ndm_data = ndm->get_data(key);
+    auto merged_data = merged->get_data(key);
+    EXPECT_EQ(ndm_data.get().size(), merged_data.get().size());
+    for (size_t j = 0; j < ndm_meta.nbytes(); j++) {
+      EXPECT_EQ(
+          ((uint8_t*)ndm_data.get().data())[j],
+          ((uint8_t*)merged_data.get().data())[j]);
+    }
+    ndm_data->Free();
+    merged_data->Free();
+  }
+}
+
+TEST_F(MergedDataMapTest, LoadNullDataMap) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(nullptr, nullptr);
+  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
+}
+
+TEST_F(MergedDataMapTest, LoadMultipleDataMaps) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(
+      data_maps_["addmul"].get(), data_maps_["linear"].get());
+  EXPECT_EQ(merged_map.error(), Error::Ok);
+}
+
+TEST_F(MergedDataMapTest, LoadDuplicateDataMapsFail) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(
+      data_maps_["addmul"].get(), data_maps_["addmul"].get());
+  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
+}
+
+TEST_F(MergedDataMapTest, CheckDataMapContents) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(
+      data_maps_["addmul"].get(), data_maps_["linear"].get());
+  EXPECT_EQ(merged_map.error(), Error::Ok);
+
+  // Num keys.
+  size_t addmul_num_keys = data_maps_["addmul"]->get_num_keys().get();
+  size_t linear_num_keys = data_maps_["linear"]->get_num_keys().get();
+  EXPECT_EQ(
+      merged_map->get_num_keys().get(), addmul_num_keys + linear_num_keys);
+
+  // Load data into is not implemented for the merged data map.
+  void* memory_block = malloc(10);
+  ASSERT_EQ(
+      Error::NotImplemented, merged_map->load_data_into("a", memory_block, 10));
+  free(memory_block);
+
+  // API calls produce equivalent results.
+  compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get());
+  compare_ndm_api_calls(data_maps_["linear"].get(), &merged_map.get());
+}
diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl
index 39ff0668d5d..1174b01f42b 100644
--- a/runtime/executor/test/targets.bzl
+++ b/runtime/executor/test/targets.bzl
@@ -125,6 +125,7 @@ def define_common_targets(is_fbcode = False):
            "ET_MODULE_STATEFUL_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleStateful.pte])",
            "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
            "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
+            "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
        }
 
    runtime.cxx_test(
@@ -142,6 +143,19 @@ def define_common_targets(is_fbcode = False):
        env = modules_env,
    )
 
+    runtime.cxx_test(
+        name = "merged_data_map_test",
+        srcs = [
+            "merged_data_map_test.cpp",
+        ],
+        deps = [
+            "//executorch/extension/data_loader:file_data_loader",
+            "//executorch/extension/flat_tensor:flat_tensor_data_map",
+            "//executorch/runtime/executor:merged_data_map",
+        ],
+        env = modules_env,
+    )
+
    runtime.cxx_test(
        name = "method_test",
        srcs = [
@@ -149,6 +163,7 @@ def define_common_targets(is_fbcode = False):
        ],
        deps = [
            ":managed_memory_manager",
+            "//executorch/runtime/executor:merged_data_map",
            "//executorch/runtime/executor:program",
            "//executorch/extension/data_loader:file_data_loader",
            "//executorch/extension/flat_tensor:flat_tensor_data_map",