diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 4b82dbf4954..4b1c30ae6b5 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -75,9 +75,7 @@ Module::Module( load_mode_(load_mode), memory_allocator_(std::make_unique()), temp_allocator_(std::make_unique()), - event_tracer_(std::move(event_tracer)), - data_map_loader_(nullptr), - data_map_(nullptr) { + event_tracer_(std::move(event_tracer)) { runtime::runtime_init(); } @@ -87,13 +85,27 @@ Module::Module( const LoadMode load_mode, std::unique_ptr event_tracer) : file_path_(file_path), - data_map_path_(data_map_path), load_mode_(load_mode), memory_allocator_(std::make_unique()), temp_allocator_(std::make_unique()), - event_tracer_(std::move(event_tracer)), - data_map_loader_(nullptr), - data_map_(nullptr) { + event_tracer_(std::move(event_tracer)) { + if (!data_map_path.empty()) { + data_files_.push_back(data_map_path); + } + runtime::runtime_init(); +} + +Module::Module( + const std::string& file_path, + std::vector data_files, + const LoadMode load_mode, + std::unique_ptr event_tracer) + : file_path_(file_path), + data_files_(std::move(data_files)), + load_mode_(load_mode), + memory_allocator_(std::make_unique()), + temp_allocator_(std::make_unique()), + event_tracer_(std::move(event_tracer)) { runtime::runtime_init(); } @@ -110,9 +122,10 @@ Module::Module( temp_allocator_( temp_allocator ? std::move(temp_allocator) : std::make_unique()), - event_tracer_(std::move(event_tracer)), - data_map_loader_(std::move(data_map_loader)), - data_map_(nullptr) { + event_tracer_(std::move(event_tracer)) { + if (data_map_loader) { + data_map_loaders_.push_back(std::move(data_map_loader)); + } runtime::runtime_init(); } @@ -129,9 +142,10 @@ Module::Module( temp_allocator_( temp_allocator ? std::move(temp_allocator) : std::make_unique()), - event_tracer_(std::move(event_tracer)), - data_map_loader_(std::move(data_map_loader)), - data_map_(nullptr) { + event_tracer_(std::move(event_tracer)) { + if (data_map_loader) { + data_map_loaders_.push_back(std::move(data_map_loader)); + } runtime::runtime_init(); } @@ -140,14 +154,27 @@ runtime::Error Module::load(const Program::Verification verification) { if (!data_loader_) { data_loader_ = ET_UNWRAP(make_data_loader(file_path_, load_mode_)); } - if (!data_map_path_.empty()) { - data_map_loader_ = - ET_UNWRAP(make_data_loader(data_map_path_, load_mode_)); + if (data_files_.size() > 0) { + ET_CHECK_OR_RETURN_ERROR( + data_files_.size() == 1, + NotImplemented, + "Multiple named data map paths are not supported yet."); + for (const auto& data_file : data_files_) { + data_map_loaders_.push_back( + ET_UNWRAP(make_data_loader(data_file, load_mode_))); + } } - if (data_map_loader_) { - data_map_ = - ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loader_.get())); + + if (data_map_loaders_.size() > 0) { + ET_CHECK_OR_RETURN_ERROR( + data_map_loaders_.size() == 1 && merged_data_map_ == nullptr, + NotImplemented, + "Multiple named data map loaders are not supported yet."); + // TODO(lfq): support multiple named data map loaders. + merged_data_map_ = + ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loaders_[0].get())); } + auto program = ET_UNWRAP_UNIQUE(Program::load(data_loader_.get(), verification)); program_ = std::shared_ptr( @@ -209,7 +236,7 @@ runtime::Error Module::load_method( method_name.c_str(), method_holder.memory_manager.get(), event_tracer ? event_tracer : this->event_tracer(), - data_map_.get())); + merged_data_map_.get())); methods_.emplace(method_name, std::move(method_holder)); } return runtime::Error::Ok; diff --git a/extension/module/module.h b/extension/module/module.h index 58ff3ada720..207de768991 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -70,7 +70,7 @@ class Module { * memory locking behavior. * * @param[in] file_path The path to the ExecuTorch program file to load. - * @param[in] data_map_path The path to a .ptd file + * @param[in] data_map_path The path to a .ptd file. * @param[in] load_mode The loading mode to use. * @param[in] event_tracer A EventTracer used for tracking and logging events. */ @@ -80,6 +80,21 @@ class Module { const LoadMode load_mode = LoadMode::File, std::unique_ptr event_tracer = nullptr); + /** + * Constructs an instance by loading a program from a file with specified + * memory locking behavior. + * + * @param[in] file_path The path to the ExecuTorch program file to load. + * @param[in] data_files The path to one or more .ptd file/s. + * @param[in] load_mode The loading mode to use. + * @param[in] event_tracer A EventTracer used for tracking and logging events. + */ + explicit Module( + const std::string& file_path, + std::vector data_files, + const LoadMode load_mode = LoadMode::File, + std::unique_ptr event_tracer = nullptr); + /** * Constructs an instance with the provided data loader and memory allocator. * @@ -614,15 +629,16 @@ class Module { }; std::string file_path_; - std::string data_map_path_; + std::vector data_files_; LoadMode load_mode_{LoadMode::File}; std::shared_ptr program_; std::unique_ptr data_loader_; std::unique_ptr memory_allocator_; std::unique_ptr temp_allocator_; std::unique_ptr event_tracer_; - std::unique_ptr data_map_loader_; - std::unique_ptr data_map_; + std::vector> data_map_loaders_; + std::vector> named_data_maps_; + std::unique_ptr merged_data_map_; ET_DEPRECATED std::vector debug_buffer_; protected: diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 1c9fc5628ba..6f7e8a44558 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -530,3 +530,18 @@ TEST_F(ModuleTest, TestPTD) { auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f}); ASSERT_EQ(module.forward(tensor).error(), Error::Ok); } + +TEST_F(ModuleTest, TestPTD_Multiple) { + std::vector data_files = {add_mul_data_path_}; + Module module(add_mul_path_, data_files); + + ASSERT_EQ(module.load_method("forward"), Error::Ok); + + auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f}); + ASSERT_EQ(module.forward(tensor).error(), Error::Ok); + + // Confirm that the data_file is not std::move'd away. + ASSERT_EQ(std::strcmp(data_files[0].c_str(), add_mul_data_path_.c_str()), 0); + + // TODO(lfq): add test when merge capability is supported. +}