diff --git a/extension/module/CMakeLists.txt b/extension/module/CMakeLists.txt index d887d873ab7..8fb2be9a677 100644 --- a/extension/module/CMakeLists.txt +++ b/extension/module/CMakeLists.txt @@ -29,7 +29,7 @@ else() endif() target_link_libraries( extension_module PRIVATE executorch_core extension_data_loader - extension_flat_tensor + extension_flat_tensor extension_named_data_map ) target_include_directories( extension_module PUBLIC ${_common_include_directories} @@ -42,8 +42,9 @@ target_compile_options( # after cleaning up CMake targets. add_library(extension_module_static STATIC ${_extension_module__srcs}) target_link_libraries( - extension_module_static PRIVATE executorch_core extension_data_loader - extension_flat_tensor + extension_module_static + PRIVATE executorch_core extension_data_loader extension_flat_tensor + extension_named_data_map ) target_include_directories( extension_module_static PUBLIC ${_common_include_directories} diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 4b1c30ae6b5..9de77bcbc79 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include /** @@ -38,6 +39,7 @@ namespace executorch { namespace extension { namespace ET_MODULE_NAMESPACE { +using ET_MERGED_DATA_MAP_NAMESPACE::MergedDataMap; using ET_RUNTIME_NAMESPACE::MethodMeta; using ET_RUNTIME_NAMESPACE::Program; @@ -155,10 +157,6 @@ runtime::Error Module::load(const Program::Verification verification) { data_loader_ = ET_UNWRAP(make_data_loader(file_path_, load_mode_)); } if (data_files_.size() > 0) { - ET_CHECK_OR_RETURN_ERROR( - data_files_.size() == 1, - NotImplemented, - "Multiple named data map paths are not supported yet."); for (const auto& data_file : data_files_) { data_map_loaders_.push_back( ET_UNWRAP(make_data_loader(data_file, load_mode_))); @@ -166,13 +164,20 @@ runtime::Error Module::load(const Program::Verification verification) { } if (data_map_loaders_.size() > 0) { - ET_CHECK_OR_RETURN_ERROR( - data_map_loaders_.size() == 1 && merged_data_map_ == nullptr, - NotImplemented, - "Multiple named data map loaders are not supported yet."); - // TODO(lfq): support multiple named data map loaders. - merged_data_map_ = - ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loaders_[0].get())); + for (auto i = 0; i < data_map_loaders_.size(); ++i) { + named_data_maps_.push_back(ET_UNWRAP_UNIQUE( + FlatTensorDataMap::load(data_map_loaders_[i].get()))); + } + + // Extract raw pointers from unique_ptrs to pass to MergedDataMap::load() + std::vector raw_data_maps; + raw_data_maps.reserve(named_data_maps_.size()); + for (const auto& data_map : named_data_maps_) { + raw_data_maps.push_back(data_map.get()); + } + merged_data_map_ = ET_UNWRAP_UNIQUE( + MergedDataMap::load(runtime::Span( + raw_data_maps.data(), raw_data_maps.size()))); } auto program = diff --git a/extension/module/targets.bzl b/extension/module/targets.bzl index 3e449da5e14..0db909ce053 100644 --- a/extension/module/targets.bzl +++ b/extension/module/targets.bzl @@ -26,6 +26,7 @@ def define_common_targets(): "//executorch/extension/data_loader:file_data_loader", "//executorch/extension/data_loader:mmap_data_loader", "//executorch/extension/flat_tensor:flat_tensor_data_map" + aten_suffix, + "//executorch/extension/named_data_map:merged_data_map" + aten_suffix, ], exported_deps = [ "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix, diff --git a/extension/module/test/CMakeLists.txt b/extension/module/test/CMakeLists.txt index 1c4358dd73e..54ace17557f 100644 --- a/extension/module/test/CMakeLists.txt +++ b/extension/module/test/CMakeLists.txt @@ -23,11 +23,14 @@ add_custom_command( OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte" "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte" "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd" + "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte" + "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd" COMMAND ${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAdd" --outdir "${CMAKE_CURRENT_BINARY_DIR}" COMMAND - ${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAddMul" - --external-constants --outdir "${CMAKE_CURRENT_BINARY_DIR}" + ${PYTHON_EXECUTABLE} -m test.models.export_program --modules + "ModuleAddMul,ModuleLinear" --external-constants --outdir + "${CMAKE_CURRENT_BINARY_DIR}" WORKING_DIRECTORY ${EXECUTORCH_ROOT} ) @@ -36,12 +39,16 @@ add_custom_target( DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte" "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte" "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd" + "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte" + "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd" ) set(test_env "ET_MODULE_ADD_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte" "ET_MODULE_ADD_MUL_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte" "ET_MODULE_ADD_MUL_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd" + "ET_MODULE_LINEAR_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte" + "ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd" ) et_cxx_test( diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 6f7e8a44558..27332503cad 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -26,11 +26,15 @@ class ModuleTest : public ::testing::Test { model_path_ = std::getenv("ET_MODULE_ADD_PATH"); add_mul_path_ = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH"); add_mul_data_path_ = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"); + linear_path_ = std::getenv("ET_MODULE_LINEAR_PROGRAM_PATH"); + linear_data_path_ = std::getenv("ET_MODULE_LINEAR_DATA_PATH"); } static inline std::string model_path_; static inline std::string add_mul_path_; static inline std::string add_mul_data_path_; + static inline std::string linear_path_; + static inline std::string linear_data_path_; }; TEST_F(ModuleTest, TestLoad) { @@ -532,16 +536,21 @@ TEST_F(ModuleTest, TestPTD) { } TEST_F(ModuleTest, TestPTD_Multiple) { - std::vector data_files = {add_mul_data_path_}; - Module module(add_mul_path_, data_files); - - ASSERT_EQ(module.load_method("forward"), Error::Ok); + std::vector data_files = {add_mul_data_path_, linear_data_path_}; + // Create module with add mul. + Module module_add_mul(add_mul_path_, data_files); + ASSERT_EQ(module_add_mul.load_method("forward"), Error::Ok); auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f}); - ASSERT_EQ(module.forward(tensor).error(), Error::Ok); + ASSERT_EQ(module_add_mul.forward(tensor).error(), Error::Ok); // Confirm that the data_file is not std::move'd away. ASSERT_EQ(std::strcmp(data_files[0].c_str(), add_mul_data_path_.c_str()), 0); + ASSERT_EQ(std::strcmp(data_files[1].c_str(), linear_data_path_.c_str()), 0); - // TODO(lfq): add test when merge capability is supported. + // Create module with linear. + Module module_linear(linear_path_, data_files); + ASSERT_EQ(module_linear.load_method("forward"), Error::Ok); + auto tensor2 = make_tensor_ptr({3}, {2.f, 3.f, 4.f}); + ASSERT_EQ(module_linear.forward(tensor2).error(), Error::Ok); } diff --git a/extension/module/test/targets.bzl b/extension/module/test/targets.bzl index d1aa73f6789..da7f1cc91bd 100644 --- a/extension/module/test/targets.bzl +++ b/extension/module/test/targets.bzl @@ -19,6 +19,8 @@ def define_common_targets(is_fbcode=False): "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])", "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])", "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])", + "ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])", + "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])", "ET_MODULE_SHARED_STATE": "$(location fbcode//executorch/test/models:exported_programs[ModuleSharedState.pte])", } diff --git a/extension/named_data_map/merged_data_map.cpp b/extension/named_data_map/merged_data_map.cpp index b42701c7587..2d1bb7d6158 100644 --- a/extension/named_data_map/merged_data_map.cpp +++ b/extension/named_data_map/merged_data_map.cpp @@ -21,7 +21,7 @@ using executorch::runtime::Result; using executorch::runtime::Span; namespace executorch::extension { - +namespace ET_MERGED_DATA_MAP_NAMESPACE { /*static*/ Result MergedDataMap::load( Span named_data_maps) { std::vector valid_data_maps; @@ -38,7 +38,7 @@ namespace executorch::extension { // Check for duplicate keys. std::unordered_map key_to_map_index; - for (auto i : c10::irange(valid_data_maps.size())) { + for (const uint32_t i : c10::irange(valid_data_maps.size())) { const auto cur_map = valid_data_maps[i]; uint32_t num_keys = cur_map->get_num_keys().get(); for (auto j : c10::irange(num_keys)) { @@ -47,7 +47,7 @@ namespace executorch::extension { ET_CHECK_OR_RETURN_ERROR( inserted, InvalidArgument, - "Duplicate key %s in named data maps at index %u and %lu", + "Duplicate key %s in named data maps at index %u and %" PRIu32, cur_key, it->second, i); @@ -114,4 +114,6 @@ ET_NODISCARD Result MergedDataMap::get_key(uint32_t index) const { // Shouldn't reach here. return Error::Internal; } + +} // namespace ET_MERGED_DATA_MAP_NAMESPACE } // namespace executorch::extension diff --git a/extension/named_data_map/merged_data_map.h b/extension/named_data_map/merged_data_map.h index 13415c0b59e..42490ec3d58 100644 --- a/extension/named_data_map/merged_data_map.h +++ b/extension/named_data_map/merged_data_map.h @@ -13,7 +13,15 @@ #include #include +#ifdef USE_ATEN_LIB +#define ET_MERGED_DATA_MAP_NAMESPACE merged_data_map::aten +#else // !USE_ATEN_LIB +#define ET_MERGED_DATA_MAP_NAMESPACE merged_data_map +#endif // USE_ATEN_LIB + namespace executorch::extension { + +namespace ET_MERGED_DATA_MAP_NAMESPACE { /** * A NamedDataMap implementation that wraps other NamedDataMaps. */ @@ -103,4 +111,5 @@ class MergedDataMap final std::unordered_map key_to_map_index_; }; +} // namespace ET_MERGED_DATA_MAP_NAMESPACE } // namespace executorch::extension diff --git a/extension/named_data_map/test/merged_data_map_test.cpp b/extension/named_data_map/test/merged_data_map_test.cpp index 4086855f439..ccfaaa0ec0e 100644 --- a/extension/named_data_map/test/merged_data_map_test.cpp +++ b/extension/named_data_map/test/merged_data_map_test.cpp @@ -23,7 +23,7 @@ using namespace ::testing; using executorch::extension::FileDataLoader; using executorch::extension::FlatTensorDataMap; -using executorch::extension::MergedDataMap; +using executorch::extension::merged_data_map::MergedDataMap; using executorch::runtime::DataLoader; using executorch::runtime::Error; using executorch::runtime::NamedDataMap; diff --git a/scripts/build_apple_frameworks.sh b/scripts/build_apple_frameworks.sh index 8ce2d68bab8..63fa4cf4545 100755 --- a/scripts/build_apple_frameworks.sh +++ b/scripts/build_apple_frameworks.sh @@ -31,6 +31,7 @@ libextension_apple.a,\ libextension_data_loader.a,\ libextension_flat_tensor.a,\ libextension_module.a,\ +libextension_named_data_map.a,\ libextension_tensor.a,\ :${FRAMEWORK_EXECUTORCH_HEADERS_DIR}:${FRAMEWORK_EXECUTORCH_MODULE_NAME}"