From d5a52956f1cf3fe11f1922ab320ca3dbb8c10c32 Mon Sep 17 00:00:00 2001 From: lucylq Date: Thu, 2 Oct 2025 19:14:20 -0700 Subject: [PATCH] Introduce public MergedDataMap Add public merged data map. Module can use this to resolve multiple named data maps. Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/) [ghstack-poisoned] --- extension/named_data_map/TARGETS | 8 + extension/named_data_map/merged_data_map.cpp | 117 +++++++++++ extension/named_data_map/merged_data_map.h | 107 ++++++++++ extension/named_data_map/targets.bzl | 21 ++ extension/named_data_map/test/TARGETS | 8 + .../test/merged_data_map_test.cpp | 187 ++++++++++++++++++ extension/named_data_map/test/targets.bzl | 27 +++ 7 files changed, 475 insertions(+) create mode 100644 extension/named_data_map/TARGETS create mode 100644 extension/named_data_map/merged_data_map.cpp create mode 100644 extension/named_data_map/merged_data_map.h create mode 100644 extension/named_data_map/targets.bzl create mode 100644 extension/named_data_map/test/TARGETS create mode 100644 extension/named_data_map/test/merged_data_map_test.cpp create mode 100644 extension/named_data_map/test/targets.bzl diff --git a/extension/named_data_map/TARGETS b/extension/named_data_map/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/named_data_map/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/named_data_map/merged_data_map.cpp b/extension/named_data_map/merged_data_map.cpp new file mode 100644 index 00000000000..678a465156b --- /dev/null +++ b/extension/named_data_map/merged_data_map.cpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +using executorch::aten::string_view; +using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; +using executorch::ET_RUNTIME_NAMESPACE::TensorLayout; +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +namespace executorch { +namespace extension { + +/*static*/ Result MergedDataMap::load( + std::vector named_data_maps) { + std::vector valid_data_maps; + for (size_t i = 0; i < named_data_maps.size(); i++) { + if (named_data_maps[i] != nullptr && + named_data_maps[i]->get_num_keys().get() > 0) { + valid_data_maps.push_back(named_data_maps[i]); + } + } + ET_CHECK_OR_RETURN_ERROR( + valid_data_maps.size() > 0, + InvalidArgument, + "No non-empty named data maps provided to merge"); + + // Check for duplicate keys. + std::unordered_map key_to_map_index; + for (uint32_t i = 0; i < valid_data_maps.size(); i++) { + const auto cur_map = valid_data_maps[i]; + uint32_t num_keys = cur_map->get_num_keys().get(); + for (uint32_t j = 0; j < num_keys; ++j) { + const auto cur_key = cur_map->get_key(j).get(); + ET_CHECK_OR_RETURN_ERROR( + key_to_map_index.find(cur_key) == key_to_map_index.end(), + InvalidArgument, + "Duplicate key %s in named data maps at index %u and %u", + cur_key, + key_to_map_index.at(cur_key), + i); + key_to_map_index[cur_key] = i; + } + } + return MergedDataMap(std::move(valid_data_maps), std::move(key_to_map_index)); +} + +ET_NODISCARD Result MergedDataMap::get_tensor_layout( + string_view key) const { + ET_CHECK_OR_RETURN_ERROR( + key_to_map_index_.find(key.data()) != key_to_map_index_.end(), + NotFound, + "Key %s not found in named data maps", + key.data()); + + return named_data_maps_.at(key_to_map_index_.at(key.data())) + ->get_tensor_layout(key); +} + +ET_NODISCARD +Result MergedDataMap::get_data(string_view key) const { + ET_CHECK_OR_RETURN_ERROR( + key_to_map_index_.find(key.data()) != key_to_map_index_.end(), + NotFound, + "Key %s not found in named data maps", + key.data()); + return named_data_maps_.at(key_to_map_index_.at(key.data()))->get_data(key); +} + +ET_NODISCARD Error MergedDataMap::load_data_into( + string_view key, + void* buffer, + size_t size) const { + ET_CHECK_OR_RETURN_ERROR( + key_to_map_index_.find(key.data()) != key_to_map_index_.end(), + NotFound, + "Key %s not found in named data maps", + key.data()); + return named_data_maps_.at(key_to_map_index_.at(key.data())) + ->load_data_into(key, buffer, size); +} + +ET_NODISCARD Result MergedDataMap::get_num_keys() const { + return key_to_map_index_.size(); +} + +ET_NODISCARD Result MergedDataMap::get_key(uint32_t index) const { + uint32_t total_num_keys = get_num_keys().get(); + ET_CHECK_OR_RETURN_ERROR( + index >= 0 && index < total_num_keys, + InvalidArgument, + "Index %u out of range of size %u", + index, + total_num_keys); + for (size_t i = 0; i < named_data_maps_.size(); i++) { + auto num_keys = named_data_maps_[i]->get_num_keys().get(); + if (index < num_keys) { + return named_data_maps_[i]->get_key(index); + } + index -= num_keys; + } + // Shouldn't reach here. + return Error::Internal; +} +} // namespace extension +} // namespace executorch diff --git a/extension/named_data_map/merged_data_map.h b/extension/named_data_map/merged_data_map.h new file mode 100644 index 00000000000..90a28b0380b --- /dev/null +++ b/extension/named_data_map/merged_data_map.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace executorch { +namespace extension { +/** + * A NamedDataMap implementation that wraps other NamedDataMaps. + */ +class MergedDataMap final + : public executorch::ET_RUNTIME_NAMESPACE::NamedDataMap { + public: + /** + * Creates a new NamedDataMap that takes in other data maps. + * + * @param[in] data_maps vector of NamedDataMap pointers to merge. + * Note: the data maps must outlive the MergedDataMap instance. + */ + static executorch::runtime::Result load( + std::vector + named_data_maps); + + /** + * Retrieve the tensor_layout for the specified key. + * + * @param[in] key The name of the tensor to get metadata on. + * + * @return Error::NotFound if the key is not present. + */ + ET_NODISCARD + executorch::runtime::Result< + const executorch::ET_RUNTIME_NAMESPACE::TensorLayout> + get_tensor_layout(executorch::aten::string_view key) const override; + + /** + * Retrieve read-only data for the specified key. + * + * @param[in] key The name of the tensor to get data on. + * + * @return error if the key is not present or data cannot be loaded. + */ + ET_NODISCARD + executorch::runtime::Result get_data( + executorch::aten::string_view key) const override; + + /** + * Loads the data of the specified tensor into the provided buffer. + * + * @param[in] key The name of the tensor to get the data of. + * @param[in] buffer The buffer to load data into. Must point to at least + * `size` bytes of memory. + * @param[in] size The number of bytes to load. + * + * @returns an Error indicating if the load was successful. + */ + ET_NODISCARD executorch::runtime::Error load_data_into( + executorch::aten::string_view key, + void* buffer, + size_t size) const override; + + /** + * @returns The number of keys in the map. + */ + ET_NODISCARD executorch::runtime::Result get_num_keys() + const override; + /** + * @returns The key at the specified index, error if index out of bounds. + */ + ET_NODISCARD executorch::runtime::Result get_key( + uint32_t index) const override; + + MergedDataMap(MergedDataMap&&) noexcept = default; + + ~MergedDataMap() override = default; + + private: + MergedDataMap( + std::vector + named_data_maps, + std::unordered_map key_to_map_index) + : named_data_maps_(std::move(named_data_maps)), + key_to_map_index_(std::move(key_to_map_index)) {} + + // Not copyable or assignable. + MergedDataMap(const MergedDataMap& rhs) = delete; + MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete; + MergedDataMap& operator=(const MergedDataMap& rhs) = delete; + + std::vector + named_data_maps_; + + // Map from key to index in the named_data_maps_ vector. + std::unordered_map key_to_map_index_; +}; + +} // namespace extension +} // namespace executorch diff --git a/extension/named_data_map/targets.bzl b/extension/named_data_map/targets.bzl new file mode 100644 index 00000000000..1eff02565c2 --- /dev/null +++ b/extension/named_data_map/targets.bzl @@ -0,0 +1,21 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_aten_mode_options", "runtime") + +def define_common_targets(): + for aten_mode in get_aten_mode_options(): + aten_suffix = "_aten" if aten_mode else "" + runtime.cxx_library( + name = "merged_data_map" + aten_suffix, + srcs = [ + "merged_data_map.cpp", + ], + exported_headers = [ + "merged_data_map.h", + ], + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/runtime/core:named_data_map", + "//executorch/runtime/core:core", + ], + ) diff --git a/extension/named_data_map/test/TARGETS b/extension/named_data_map/test/TARGETS new file mode 100644 index 00000000000..883ab644309 --- /dev/null +++ b/extension/named_data_map/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets(is_fbcode=True) diff --git a/extension/named_data_map/test/merged_data_map_test.cpp b/extension/named_data_map/test/merged_data_map_test.cpp new file mode 100644 index 00000000000..14f2b57c04f --- /dev/null +++ b/extension/named_data_map/test/merged_data_map_test.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include + +using namespace ::testing; +using executorch::extension::FileDataLoader; +using executorch::extension::FlatTensorDataMap; +using executorch::extension::MergedDataMap; +using executorch::runtime::DataLoader; +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::NamedDataMap; +using executorch::runtime::Result; +using executorch::runtime::TensorLayout; + +class MergedDataMapTest : public ::testing::Test { + protected: + void load_flat_tensor_data_map(const char* path, const char* module_name) { + Result loader = FileDataLoader::from(path); + ASSERT_EQ(loader.error(), Error::Ok); + loaders_.insert( + {module_name, + std::make_unique(std::move(loader.get()))}); + + Result data_map = + FlatTensorDataMap::load(loaders_[module_name].get()); + EXPECT_EQ(data_map.error(), Error::Ok); + + data_maps_.insert( + {module_name, + std::make_unique(std::move(data_map.get()))}); + } + + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Load FlatTensor data maps. + // The eager addmul and linear models are defined at: + // //executorch/test/models/export_program.py + load_flat_tensor_data_map( + std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"), "addmul"); + load_flat_tensor_data_map( + std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear"); + load_flat_tensor_data_map( + std::getenv("ET_MODULE_SIMPLE_TRAIN_DATA_PATH"), "simple_train"); + } + + private: + // Must outlive data_maps_, but tests shouldn't need to touch it. + std::unordered_map> loaders_; + + protected: + std::unordered_map> data_maps_; +}; + +// Check that two tensor layouts are equivalent. +void check_tensor_layout(TensorLayout& layout1, TensorLayout& layout2) { + EXPECT_EQ(layout1.scalar_type(), layout2.scalar_type()); + EXPECT_EQ(layout1.nbytes(), layout2.nbytes()); + EXPECT_EQ(layout1.sizes().size(), layout2.sizes().size()); + for (size_t i = 0; i < layout1.sizes().size(); i++) { + EXPECT_EQ(layout1.sizes()[i], layout2.sizes()[i]); + } + EXPECT_EQ(layout1.dim_order().size(), layout2.dim_order().size()); + for (size_t i = 0; i < layout1.dim_order().size(); i++) { + EXPECT_EQ(layout1.dim_order()[i], layout2.dim_order()[i]); + } +} + +// Given that ndm is part of merged, check that all the API calls on ndm produce +// the same results as merged. +void compare_ndm_api_calls( + const NamedDataMap* ndm, + const NamedDataMap* merged) { + uint32_t num_keys = ndm->get_num_keys().get(); + for (uint32_t i = 0; i < num_keys; i++) { + auto key = ndm->get_key(i).get(); + + // Compare get_tensor_layout. + auto ndm_meta = ndm->get_tensor_layout(key).get(); + auto merged_meta = merged->get_tensor_layout(key).get(); + check_tensor_layout(ndm_meta, merged_meta); + + // Compare get_data. + auto ndm_data = ndm->get_data(key); + auto merged_data = merged->get_data(key); + EXPECT_EQ(ndm_data.get().size(), merged_data.get().size()); + for (size_t j = 0; j < ndm_meta.nbytes(); j++) { + EXPECT_EQ( + ((uint8_t*)ndm_data.get().data())[j], + ((uint8_t*)merged_data.get().data())[j]); + } + ndm_data->Free(); + merged_data->Free(); + + // Compare load_into. + auto nbytes = ndm_meta.nbytes(); + void* ndm_buffer = malloc(nbytes); + auto ndm_load_into = ndm->load_data_into(key, ndm_buffer, nbytes); + EXPECT_EQ(ndm_load_into, Error::Ok); + void* merged_buffer = malloc(nbytes); + auto merged_load_into = merged->load_data_into(key, merged_buffer, nbytes); + EXPECT_EQ(merged_load_into, Error::Ok); + for (size_t j = 0; j < ndm_meta.nbytes(); j++) { + EXPECT_EQ(((uint8_t*)merged_buffer)[j], ((uint8_t*)merged_buffer)[j]); + } + free(ndm_buffer); + free(merged_buffer); + } +} + +TEST_F(MergedDataMapTest, LoadNullDataMap) { + Result merged_map = MergedDataMap::load({nullptr, nullptr}); + EXPECT_EQ(merged_map.error(), Error::InvalidArgument); +} + +TEST_F(MergedDataMapTest, LoadMultipleDataMaps) { + std::vector ndms = { + data_maps_["addmul"].get(), data_maps_["linear"].get()}; + Result merged_map = MergedDataMap::load(ndms); + EXPECT_EQ(merged_map.error(), Error::Ok); + + std::vector ndms2 = { + data_maps_["addmul"].get(), data_maps_["simple_train"].get()}; + Result merged_map2 = MergedDataMap::load(ndms2); + EXPECT_EQ(merged_map2.error(), Error::Ok); +} + +TEST_F(MergedDataMapTest, LoadSingleDataMap) { + std::vector ndms = {data_maps_["addmul"].get(), nullptr}; + Result merged_map = MergedDataMap::load(ndms); + EXPECT_EQ(merged_map.error(), Error::Ok); + + // Num keys. + EXPECT_EQ( + merged_map->get_num_keys().get(), + data_maps_["addmul"]->get_num_keys().get()); + + // API calls produce equivalent results. + compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get()); +} + +TEST_F(MergedDataMapTest, LoadDuplicateDataMapsFail) { + std::vector ndms = { + data_maps_["addmul"].get(), data_maps_["addmul"].get()}; + Result merged_map = MergedDataMap::load(ndms); + EXPECT_EQ(merged_map.error(), Error::InvalidArgument); + + std::vector ndms2 = { + data_maps_["addmul"].get(), + data_maps_["linear"].get(), + data_maps_["simple_train"].get()}; + Result merged_map2 = MergedDataMap::load(ndms); + EXPECT_EQ(merged_map2.error(), Error::InvalidArgument); +} + +TEST_F(MergedDataMapTest, CheckDataMapContents) { + std::vector ndms = { + data_maps_["addmul"].get(), data_maps_["linear"].get()}; + Result merged_map = MergedDataMap::load(ndms); + EXPECT_EQ(merged_map.error(), Error::Ok); + + // Num keys. + size_t addmul_num_keys = data_maps_["addmul"]->get_num_keys().get(); + size_t linear_num_keys = data_maps_["linear"]->get_num_keys().get(); + EXPECT_EQ( + merged_map->get_num_keys().get(), addmul_num_keys + linear_num_keys); + + // API calls produce equivalent results. + compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get()); + compare_ndm_api_calls(data_maps_["linear"].get(), &merged_map.get()); +} diff --git a/extension/named_data_map/test/targets.bzl b/extension/named_data_map/test/targets.bzl new file mode 100644 index 00000000000..3d543b22cf0 --- /dev/null +++ b/extension/named_data_map/test/targets.bzl @@ -0,0 +1,27 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(is_fbcode=False): + if not runtime.is_oss and is_fbcode: + modules_env = { + # The tests use this var to find the program file to load. This uses + # an fbcode target path because the authoring/export tools + # intentionally don't work in xplat (since they're host-only tools). + "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])", + "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])", + "ET_MODULE_SIMPLE_TRAIN_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleSimpleTrain.ptd])", + } + + runtime.cxx_test( + name = "merged_data_map_test", + srcs = [ + "merged_data_map_test.cpp", + ], + deps = [ + "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/flat_tensor:flat_tensor_data_map", + "//executorch/extension/named_data_map:merged_data_map", + "//executorch/runtime/core:named_data_map", + "//executorch/runtime/core/exec_aten:lib", + ], + env = modules_env, + )