diff --git a/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp b/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp index 342e3478e0f..18953da1ec7 100644 --- a/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp +++ b/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp @@ -87,15 +87,15 @@ TEST_F(DataSeparationTest, TestExternalData) { // Check that accessing keys out of bounds fails. EXPECT_EQ(data_map->get_key(2).error(), Error::InvalidArgument); - // Linear.weight + // Linear.bias Result data0 = data_map->get_data(key0.get()); EXPECT_EQ(data0.error(), Error::Ok); - EXPECT_EQ(data0.get().size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float) + EXPECT_EQ(data0.get().size(), 12); // 3*4 (3 vector, 4 bytes per float) - // Linear.bias + // Linear.weight Result data1 = data_map->get_data(key1.get()); EXPECT_EQ(data1.error(), Error::Ok); - EXPECT_EQ(data1.get().size(), 12); // 3*4 (3 vector, 4 bytes per float) + EXPECT_EQ(data1.get().size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float) // Check that accessing non-existent data fails. Result data2 = data_map->get_data("nonexistent"); diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index 3a69dc8b92c..24e4b51306e 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -19,15 +19,15 @@ #include #include +using executorch::aten::ScalarType; +using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; +using executorch::ET_RUNTIME_NAMESPACE::TensorLayout; +using executorch::runtime::DataLoader; using executorch::runtime::Error; using executorch::runtime::FreeableBuffer; using executorch::runtime::Result; using executorch::runtime::Span; -using executorch::aten::ScalarType; -using executorch::ET_RUNTIME_NAMESPACE::TensorLayout; -using executorch::runtime::DataLoader; - namespace executorch { namespace extension { @@ -103,82 +103,111 @@ Result create_tensor_layout( ET_NODISCARD Result FlatTensorDataMap::get_tensor_layout( executorch::aten::string_view key) const { - Result named_data = get_named_data( - key, - flat_tensor_->named_data(), - flat_tensor_->segments(), - header_.segment_base_offset + header_.segment_data_size); - if (!named_data.ok()) { + if (key_to_map_index_.find(key.data()) == key_to_map_index_.end()) { + return Error::NotFound; + } + auto index = key_to_map_index_.at(key.data()); + if (index == -1) { + Result named_data = + get_named_data( + key, + flat_tensor_->named_data(), + flat_tensor_->segments(), + header_.segment_base_offset + header_.segment_data_size); + if (named_data.ok()) { + return create_tensor_layout(named_data.get()->tensor_layout()); + } return named_data.error(); + } else { + return merged_maps_[index]->get_tensor_layout(key); } - return create_tensor_layout(named_data.get()->tensor_layout()); } ET_NODISCARD Result FlatTensorDataMap::get_data( executorch::aten::string_view key) const { - Result named_data = get_named_data( - key, - flat_tensor_->named_data(), - flat_tensor_->segments(), - header_.segment_base_offset + header_.segment_data_size); - if (!named_data.ok()) { + if (key_to_map_index_.find(key.data()) == key_to_map_index_.end()) { + return Error::NotFound; + } + auto index = key_to_map_index_.at(key.data()); + if (index == -1) { + Result named_data = + get_named_data( + key, + flat_tensor_->named_data(), + flat_tensor_->segments(), + header_.segment_base_offset + header_.segment_data_size); + if (named_data.ok()) { + uint32_t segment_index = named_data.get()->segment_index(); + uint64_t segment_offset = + flat_tensor_->segments()->Get(segment_index)->offset(); + uint64_t segment_size = + flat_tensor_->segments()->Get(segment_index)->size(); + + return loader_->load( + /*offset=*/header_.segment_base_offset + segment_offset, + segment_size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); + } return named_data.error(); + } else { + return merged_maps_[index]->get_data(key); } - - uint32_t segment_index = named_data.get()->segment_index(); - uint64_t segment_offset = - flat_tensor_->segments()->Get(segment_index)->offset(); - uint64_t segment_size = flat_tensor_->segments()->Get(segment_index)->size(); - - return loader_->load( - /*offset=*/header_.segment_base_offset + segment_offset, - segment_size, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); } ET_NODISCARD Error FlatTensorDataMap::load_data_into( ET_UNUSED executorch::aten::string_view key, ET_UNUSED void* buffer, ET_UNUSED size_t size) const { - Result named_data = get_named_data( - key, - flat_tensor_->named_data(), - flat_tensor_->segments(), - header_.segment_base_offset + header_.segment_data_size); - if (!named_data.ok()) { - return named_data.error(); + if (key_to_map_index_.find(key.data()) == key_to_map_index_.end()) { + return Error::NotFound; } + auto index = key_to_map_index_.at(key.data()); + if (index == -1) { + Result named_data = + get_named_data( + key, + flat_tensor_->named_data(), + flat_tensor_->segments(), + header_.segment_base_offset + header_.segment_data_size); + if (!named_data.ok()) { + return named_data.error(); + } - uint32_t segment_index = named_data.get()->segment_index(); - uint64_t segment_offset = - flat_tensor_->segments()->Get(segment_index)->offset(); + uint32_t segment_index = named_data.get()->segment_index(); + uint64_t segment_offset = + flat_tensor_->segments()->Get(segment_index)->offset(); - Result tensor_layout = - create_tensor_layout(named_data.get()->tensor_layout()); + Result tensor_layout = + create_tensor_layout(named_data.get()->tensor_layout()); - if (!tensor_layout.ok()) { - return tensor_layout.error(); - } + if (!tensor_layout.ok()) { + return tensor_layout.error(); + } - ET_CHECK_OR_RETURN_ERROR( - size <= tensor_layout.get().nbytes(), - InvalidArgument, - "Buffer size %zu is smaller than tensor size %zu", - size, - tensor_layout.get().nbytes()); - - // Load mutable data. - DataLoader::SegmentInfo info = DataLoader::SegmentInfo( - DataLoader::SegmentInfo::Type::Mutable, 0, nullptr); - return loader_->load_into( - header_.segment_base_offset + segment_offset, - tensor_layout.get().nbytes(), - info, - buffer); + ET_CHECK_OR_RETURN_ERROR( + size <= tensor_layout.get().nbytes(), + InvalidArgument, + "Buffer size %zu is smaller than tensor size %zu", + size, + tensor_layout.get().nbytes()); + + // Load mutable data. + DataLoader::SegmentInfo info = DataLoader::SegmentInfo( + DataLoader::SegmentInfo::Type::Mutable, 0, nullptr); + return loader_->load_into( + header_.segment_base_offset + segment_offset, + tensor_layout.get().nbytes(), + info, + buffer); + } else { + return merged_maps_[index]->load_data_into(key, buffer, size); + } } ET_NODISCARD Result FlatTensorDataMap::get_num_keys() const { - return flat_tensor_->named_data()->size(); + // Guaranteed safe, as the segment_index is a uint32_t, which means + // that there can't be more than uint32_t keys. + return static_cast(key_to_map_index_.size()); } ET_NODISCARD Result FlatTensorDataMap::get_key( @@ -190,7 +219,40 @@ ET_NODISCARD Result FlatTensorDataMap::get_key( "Index %u out of range of size %u", index, num_keys); - return flat_tensor_->named_data()->Get(index)->key()->c_str(); + + uint32_t current_index = 0; + for (const auto& pair : key_to_map_index_) { + if (current_index == index) { + return pair.first.c_str(); + } + current_index++; + } + return Error::NotFound; +} + +ET_NODISCARD Error FlatTensorDataMap::merge(const NamedDataMap* other) { + ET_CHECK_OR_RETURN_ERROR( + other != nullptr, InvalidArgument, "Merge error: other is nullptr."); + + // Check if any duplicate keys exist. + uint32_t num_keys = other->get_num_keys().get(); + + for (uint32_t i = 0; i < num_keys; i++) { + const char* key = other->get_key(i).get(); + ET_CHECK_OR_RETURN_ERROR( + key_to_map_index_.find(key) == key_to_map_index_.end(), + InvalidArgument, + "Merge error: key %s already exists in the named_data_map.", + key); + } + // Place keys into the map. + for (uint32_t i = 0; i < num_keys; i++) { + const char* key = other->get_key(i).get(); + key_to_map_index_[key] = static_cast(merged_maps_.size()); + } + + merged_maps_.push_back(other); + return Error::Ok; } /* static */ Result FlatTensorDataMap::load( @@ -261,8 +323,18 @@ ET_NODISCARD Result FlatTensorDataMap::get_key( InvalidExternalData, "FlatTensor segments is nullptr, malformed PTD file."); + // Add keys to the map. + std::unordered_map key_to_map_index; + for (int i = 0; i < flat_tensor->named_data()->size(); i++) { + const auto* named_data = flat_tensor->named_data()->Get(i); + key_to_map_index[named_data->key()->c_str()] = -1; + } return FlatTensorDataMap( - fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader); + fh.get(), + std::move(flat_tensor_data.get()), + flat_tensor, + loader, + std::move(key_to_map_index)); } } // namespace extension diff --git a/extension/flat_tensor/flat_tensor_data_map.h b/extension/flat_tensor/flat_tensor_data_map.h index 751e312f7ef..006766d66f8 100644 --- a/extension/flat_tensor/flat_tensor_data_map.h +++ b/extension/flat_tensor/flat_tensor_data_map.h @@ -8,17 +8,18 @@ #pragma once -#include - #include +#include #include #include #include #include #include +#include #include +#include // Forward declare flatbuffer types. This is a public header and must not // include the generated flatbuffer header. @@ -94,6 +95,17 @@ class FlatTensorDataMap final ET_NODISCARD executorch::runtime::Result get_key( uint32_t index) const override; + /** + * Merge a named_data_map into the current one. + * @param[in] other The named_data_map to merge. + * @return Error indicating if the merge was successful or not. + * + * Note: The FlatTensorDataMap does not perform a deep copy; it holds a + * reference to other, so other must outlive the FlatTensorDataMap instance. + */ + ET_NODISCARD executorch::runtime::Error merge( + const NamedDataMap* other) override; + FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default; ~FlatTensorDataMap() override = default; @@ -103,11 +115,14 @@ class FlatTensorDataMap final const FlatTensorHeader& header, executorch::runtime::FreeableBuffer&& flat_tensor_data, const flat_tensor_flatbuffer::FlatTensor* flat_tensor, - executorch::runtime::DataLoader* loader) + executorch::runtime::DataLoader* loader, + std::unordered_map key_to_map_index) : header_(header), flat_tensor_data_(std::move(flat_tensor_data)), flat_tensor_(flat_tensor), - loader_(loader) {} + loader_(loader), + key_to_map_index_(std::move(key_to_map_index)), + merged_maps_({}) {} // Not copyable or assignable. FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete; @@ -125,6 +140,13 @@ class FlatTensorDataMap final // Data loader, used to load segment data. executorch::runtime::DataLoader* loader_; + + // Cache of keys to data map index. + // index=-1 is used for the flat_tensor data map. + std::unordered_map key_to_map_index_; + + // Other NamedDataMaps. + std::vector merged_maps_; }; } // namespace extension diff --git a/extension/flat_tensor/test/CMakeLists.txt b/extension/flat_tensor/test/CMakeLists.txt index c3296dc61f3..8910732369b 100644 --- a/extension/flat_tensor/test/CMakeLists.txt +++ b/extension/flat_tensor/test/CMakeLists.txt @@ -21,21 +21,23 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) add_custom_command( OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte" "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd" + "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte" + "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd" COMMAND - ${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAddMul" + ${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAddMul,ModuleLinear" --external-constants --outdir "${CMAKE_CURRENT_BINARY_DIR}" 2> /dev/null WORKING_DIRECTORY ${EXECUTORCH_ROOT} ) add_custom_target( extension_flat_tensor_test_resources - DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte" - "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd" + DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd" + "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd" ) set(test_env - "ET_MODULE_ADD_MUL_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte" "ET_MODULE_ADD_MUL_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd" + "ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd" ) set(_test_srcs flat_tensor_data_map_test.cpp flat_tensor_header_test.cpp) diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp index 5a94b47b954..5136ba5b67c 100644 --- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp +++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp @@ -35,25 +35,33 @@ class FlatTensorDataMapTest : public ::testing::Test { // Load data map. The eager linear model is defined at: // //executorch/test/models/linear_model.py - const char* path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"); - Result loader = FileDataLoader::from(path); - ASSERT_EQ(loader.error(), Error::Ok); + const char* add_mul_path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"); + Result add_mul_loader = FileDataLoader::from(add_mul_path); + ASSERT_EQ(add_mul_loader.error(), Error::Ok); - data_map_loader_ = - std::make_unique(std::move(loader.get())); + add_mul_loader_ = + std::make_unique(std::move(add_mul_loader.get())); + + const char* linear_path = std::getenv("ET_MODULE_LINEAR_DATA_PATH"); + Result linear_loader = FileDataLoader::from(linear_path); + ASSERT_EQ(linear_loader.error(), Error::Ok); + + linear_loader_ = + std::make_unique(std::move(linear_loader.get())); } - std::unique_ptr data_map_loader_; + std::unique_ptr add_mul_loader_; + std::unique_ptr linear_loader_; }; -TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) { +TEST_F(FlatTensorDataMapTest, LoadDataMap) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); - EXPECT_EQ(data_map.error(), Error::Ok); + FlatTensorDataMap::load(add_mul_loader_.get()); + ASSERT_EQ(data_map.error(), Error::Ok); } -TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) { +TEST_F(FlatTensorDataMapTest, GetTensorLayout) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); + FlatTensorDataMap::load(add_mul_loader_.get()); EXPECT_EQ(data_map.error(), Error::Ok); // Check tensor layouts are correct. @@ -93,10 +101,10 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) { EXPECT_EQ(const_c_res.error(), Error::NotFound); } -TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) { +TEST_F(FlatTensorDataMapTest, GetData) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); - EXPECT_EQ(data_map.error(), Error::Ok); + FlatTensorDataMap::load(add_mul_loader_.get()); + ASSERT_EQ(data_map.error(), Error::Ok); // Check tensor data sizes are correct. Result data_a_res = data_map->get_data("a"); @@ -114,10 +122,10 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) { EXPECT_EQ(data_c_res.error(), Error::NotFound); } -TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) { +TEST_F(FlatTensorDataMapTest, Keys) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); - EXPECT_EQ(data_map.error(), Error::Ok); + FlatTensorDataMap::load(add_mul_loader_.get()); + ASSERT_EQ(data_map.error(), Error::Ok); // Check num tensors is 2. Result num_tensors_res = data_map->get_num_keys(); @@ -127,30 +135,30 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) { // Check get_key returns the correct keys. Result key0_res = data_map->get_key(0); ASSERT_EQ(Error::Ok, key0_res.error()); - EXPECT_EQ(strcmp(key0_res.get(), "a"), 0); + EXPECT_EQ(strcmp(key0_res.get(), "b"), 0); Result key1_res = data_map->get_key(1); ASSERT_EQ(Error::Ok, key1_res.error()); - EXPECT_EQ(strcmp(key1_res.get(), "b"), 0); + EXPECT_EQ(strcmp(key1_res.get(), "a"), 0); // Check get_key fails when out of bounds. Result key2_res = data_map->get_key(2); EXPECT_EQ(key2_res.error(), Error::InvalidArgument); } -TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) { +TEST_F(FlatTensorDataMapTest, LoadInto) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); - EXPECT_EQ(data_map.error(), Error::Ok); + FlatTensorDataMap::load(add_mul_loader_.get()); + ASSERT_EQ(data_map.error(), Error::Ok); - // get the metadata - auto meta_data_res = data_map->get_tensor_layout("a"); - ASSERT_EQ(meta_data_res.error(), Error::Ok); + // Get tensor layout. + auto tensor_layout = data_map->get_tensor_layout("a"); + ASSERT_EQ(tensor_layout.error(), Error::Ok); - // get data blob - void* data = malloc(meta_data_res->nbytes()); + // Get data blob. + void* data = malloc(tensor_layout->nbytes()); auto load_into_error = - data_map->load_data_into("a", data, meta_data_res->nbytes()); + data_map->load_data_into("a", data, tensor_layout->nbytes()); ASSERT_EQ(load_into_error, Error::Ok); // Check tensor data is correct. @@ -160,3 +168,50 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) { } free(data); } + +TEST_F(FlatTensorDataMapTest, MergeDuplicates) { + Result data_map = + FlatTensorDataMap::load(add_mul_loader_.get()); + ASSERT_EQ(data_map.error(), Error::Ok); + + Result data_map2 = + FlatTensorDataMap::load(add_mul_loader_.get()); + ASSERT_EQ(data_map2.error(), Error::Ok); + + Error error = data_map->merge(&data_map2.get()); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(FlatTensorDataMapTest, Merge) { + Result data_map = + FlatTensorDataMap::load(add_mul_loader_.get()); + ASSERT_EQ(data_map.error(), Error::Ok); + + Result linear_data_map = + FlatTensorDataMap::load(linear_loader_.get()); + ASSERT_EQ(linear_data_map.error(), Error::Ok); + + Error error = data_map->merge(&linear_data_map.get()); + ASSERT_EQ(error, Error::Ok); + + // Check num tensors is 4. + Result num_keys = data_map->get_num_keys(); + ASSERT_EQ(Error::Ok, num_keys.error()); + EXPECT_EQ(num_keys.get(), 4); + + // Check data map has the new linear keys. + Result linear_weight = + data_map->get_tensor_layout("linear.weight"); + EXPECT_EQ(Error::Ok, linear_weight.error()); + + Result linear_bias = + data_map->get_tensor_layout("linear.bias"); + EXPECT_EQ(Error::Ok, linear_bias.error()); + + // Check data map still has the add_mul keys. + Result a = data_map->get_tensor_layout("a"); + EXPECT_EQ(Error::Ok, a.error()); + + Result b = data_map->get_tensor_layout("b"); + EXPECT_EQ(Error::Ok, b.error()); +} diff --git a/extension/flat_tensor/test/targets.bzl b/extension/flat_tensor/test/targets.bzl index 4d798cc1a7c..75544da11a1 100644 --- a/extension/flat_tensor/test/targets.bzl +++ b/extension/flat_tensor/test/targets.bzl @@ -35,8 +35,8 @@ def define_common_targets(is_fbcode=False): # The tests use this var to find the program file to load. This uses # an fbcode target path because the authoring/export tools # intentionally don't work in xplat (since they're host-only tools). - "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])", "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])", + "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])", } runtime.cxx_test( diff --git a/runtime/core/named_data_map.h b/runtime/core/named_data_map.h index 7503f0b2979..e3fd0da3961 100644 --- a/runtime/core/named_data_map.h +++ b/runtime/core/named_data_map.h @@ -77,6 +77,14 @@ class ET_EXPERIMENTAL NamedDataMap { * pointer is only valid for the lifetime of the DataMap. */ ET_NODISCARD virtual Result get_key(uint32_t index) const = 0; + + /** + * Merge a named_data_map into the current one. + * + * @param other The named_data_map to merge. + * @return Error indicating if the merge was successful or not. + */ + ET_NODISCARD virtual Error merge(const NamedDataMap* other) = 0; }; } // namespace ET_RUNTIME_NAMESPACE diff --git a/runtime/executor/pte_data_map.h b/runtime/executor/pte_data_map.h index b4b46a6b541..c9fefadf0f4 100644 --- a/runtime/executor/pte_data_map.h +++ b/runtime/executor/pte_data_map.h @@ -114,6 +114,13 @@ class PteDataMap final : public NamedDataMap { */ ET_NODISCARD Result get_key(uint32_t index) const override; + /** + * The PteDataMap does not implement merge. + */ + ET_NODISCARD Error merge(ET_UNUSED const NamedDataMap* other) override { + return Error::NotImplemented; + } + // Moveable, to be compatible with Result. PteDataMap(PteDataMap&&) noexcept = default; ~PteDataMap() override = default;