@@ -173,27 +173,51 @@ inline std::unique_ptr<Module> load_module_from_buffer(
173173}
174174
175175inline std::unique_ptr<Module> load_module_from_file (
176- const std::string& path,
176+ const std::string& program_path,
177+ std::optional<const std::string>& data_map_path,
177178 bool enable_etdump,
178179 size_t debug_buffer_size,
179180 Program::Verification program_verification) {
180181 EXECUTORCH_SCOPE_PROF (" load_module_from_file" );
181182
182- Result<MmapDataLoader> res = MmapDataLoader::from (
183- path .c_str (), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
183+ Result<MmapDataLoader> program_loader_res = MmapDataLoader::from (
184+ program_path .c_str (), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
184185 THROW_IF_ERROR (
185- res .error (),
186+ program_loader_res .error (),
186187 " Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
187- path.c_str (),
188- static_cast <uint32_t >(res.error ()));
189-
190- auto loader = std::make_unique<MmapDataLoader>(std::move (res.get ()));
191- return std::make_unique<Module>(
192- std::move (loader),
193- nullptr , // memory_allocator
194- nullptr , // temp_allocator
195- enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr ,
196- nullptr ); // data_map_loader
188+ program_path.c_str (),
189+ static_cast <uint32_t >(program_loader_res.error ()));
190+ auto program_loader =
191+ std::make_unique<MmapDataLoader>(std::move (program_loader_res.get ()));
192+
193+ if (data_map_path.has_value ()) {
194+ Result<MmapDataLoader> data_map_loader_res = MmapDataLoader::from (
195+ data_map_path->c_str (),
196+ MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
197+ THROW_IF_ERROR (
198+ data_map_loader_res.error (),
199+ " Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
200+ data_map_path->c_str (),
201+ static_cast <uint32_t >(data_map_loader_res.error ()));
202+ auto data_map_loader =
203+ std::make_unique<MmapDataLoader>(std::move (data_map_loader_res.get ()));
204+
205+ return std::make_unique<Module>(
206+ std::move (program_loader),
207+ nullptr , // memory_allocator
208+ nullptr , // temp_allocator
209+ enable_etdump ? std::make_unique<torch::executor::ETDumpGen>()
210+ : nullptr ,
211+ std::move (data_map_loader)); // data_map_loader
212+ } else {
213+ return std::make_unique<Module>(
214+ std::move (program_loader),
215+ nullptr , // memory_allocator
216+ nullptr , // temp_allocator
217+ enable_etdump ? std::make_unique<torch::executor::ETDumpGen>()
218+ : nullptr ,
219+ nullptr ); // data_map_loader
220+ }
197221}
198222
199223inline py::list get_outputs_as_py_list (
@@ -495,13 +519,15 @@ struct PyModule final {
495519 program_verification)) {}
496520
497521 explicit PyModule (
498- const std::string& path,
522+ const std::string& program_path,
523+ std::optional<const std::string>& data_path,
499524 bool enable_etdump,
500525 size_t debug_buffer_size = 0 ,
501526 Program::Verification program_verification =
502527 Program::Verification::InternalConsistency)
503528 : module_(load_module_from_file(
504- path,
529+ program_path,
530+ data_path,
505531 enable_etdump,
506532 debug_buffer_size,
507533 program_verification)) {}
@@ -521,14 +547,20 @@ struct PyModule final {
521547 return std::make_unique<PyModule>(
522548 buffer, enable_etdump, debug_buffer_size, program_verification);
523549 }
550+
524551 static std::unique_ptr<PyModule> load_from_file (
525- const std::string& path,
552+ const std::string& program_path,
553+ std::optional<const std::string>& data_path,
526554 bool enable_etdump,
527555 size_t debug_buffer_size = 0 ,
528556 Program::Verification program_verification =
529557 Program::Verification::InternalConsistency) {
530558 return std::make_unique<PyModule>(
531- path, enable_etdump, debug_buffer_size, program_verification);
559+ program_path,
560+ data_path,
561+ enable_etdump,
562+ debug_buffer_size,
563+ program_verification);
532564 }
533565
534566 static std::unique_ptr<PyModule> load_from_bundled_program (
@@ -1301,7 +1333,8 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
13011333 m.def (
13021334 " _load_for_executorch" ,
13031335 PyModule::load_from_file,
1304- py::arg (" path" ),
1336+ py::arg (" program_path" ),
1337+ py::arg (" data_path" ) = std::nullopt ,
13051338 py::arg (" enable_etdump" ) = false ,
13061339 py::arg (" debug_buffer_size" ) = 0 ,
13071340 py::arg (" program_verification" ) =
0 commit comments