1010
1111#include < executorch/extension/data_loader/file_data_loader.h>
1212#include < executorch/extension/data_loader/mmap_data_loader.h>
13+ #include < executorch/extension/flat_tensor/flat_tensor_data_map.h>
1314#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
1415#include < executorch/runtime/platform/runtime.h>
1516
3637namespace executorch {
3738namespace extension {
3839
40+ namespace {
41+ runtime::Result<std::unique_ptr<runtime::DataLoader>> load_file (
42+ const std::string& file_path,
43+ Module::LoadMode mode) {
44+ std::unique_ptr<runtime::DataLoader> res = nullptr ;
45+ switch (mode) {
46+ case Module::LoadMode::File:
47+ res = ET_UNWRAP_UNIQUE (FileDataLoader::from (file_path.c_str ()));
48+ break ;
49+ case Module::LoadMode::Mmap:
50+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
51+ file_path.c_str (), MmapDataLoader::MlockConfig::NoMlock));
52+ break ;
53+ case Module::LoadMode::MmapUseMlock:
54+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (file_path.c_str ()));
55+ break ;
56+ case Module::LoadMode::MmapUseMlockIgnoreErrors:
57+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
58+ file_path.c_str (),
59+ MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
60+ break ;
61+ }
62+ return res;
63+ }
64+ } // namespace
65+
66+ Module::Module (
67+ const std::string& file_path,
68+ const LoadMode load_mode,
69+ std::unique_ptr<runtime::EventTracer> event_tracer)
70+ : file_path_(file_path),
71+ load_mode_ (load_mode),
72+ memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
73+ temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
74+ event_tracer_(std::move(event_tracer)),
75+ data_map_loader_(nullptr ),
76+ data_map_(nullptr ) {
77+ runtime::runtime_init ();
78+ }
79+
3980Module::Module (
4081 const std::string& file_path,
82+ const std::string& data_map_path,
4183 const LoadMode load_mode,
4284 std::unique_ptr<runtime::EventTracer> event_tracer)
4385 : file_path_(file_path),
86+ data_map_path_(data_map_path),
4487 load_mode_(load_mode),
4588 memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
4689 temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47- event_tracer_(std::move(event_tracer)) {
90+ event_tracer_(std::move(event_tracer)),
91+ data_map_loader_(nullptr ),
92+ data_map_(nullptr ) {
4893 runtime::runtime_init ();
4994}
5095
5196Module::Module (
5297 std::unique_ptr<runtime::DataLoader> data_loader,
5398 std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
5499 std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55- std::unique_ptr<runtime::EventTracer> event_tracer)
100+ std::unique_ptr<runtime::EventTracer> event_tracer,
101+ std::unique_ptr<runtime::DataLoader> data_map_loader)
56102 : data_loader_(std::move(data_loader)),
57103 memory_allocator_(
58104 memory_allocator ? std::move(memory_allocator)
59105 : std::make_unique<MallocMemoryAllocator>()),
60106 temp_allocator_(
61107 temp_allocator ? std::move(temp_allocator)
62108 : std::make_unique<MallocMemoryAllocator>()),
63- event_tracer_(std::move(event_tracer)) {
109+ event_tracer_(std::move(event_tracer)),
110+ data_map_loader_(std::move(data_map_loader)),
111+ data_map_(nullptr ) {
64112 runtime::runtime_init ();
65113}
66114
67115Module::Module (
68116 std::shared_ptr<runtime::Program> program,
69117 std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70118 std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71- std::unique_ptr<runtime::EventTracer> event_tracer)
119+ std::unique_ptr<runtime::EventTracer> event_tracer,
120+ std::unique_ptr<runtime::DataLoader> data_map_loader)
72121 : program_(std::move(program)),
73122 memory_allocator_(
74123 memory_allocator ? std::move(memory_allocator)
75124 : std::make_unique<MallocMemoryAllocator>()),
76125 temp_allocator_(
77126 temp_allocator ? std::move(temp_allocator)
78127 : std::make_unique<MallocMemoryAllocator>()),
79- event_tracer_(std::move(event_tracer)) {
128+ event_tracer_(std::move(event_tracer)),
129+ data_map_loader_(std::move(data_map_loader)),
130+ data_map_(nullptr ) {
80131 runtime::runtime_init ();
81132}
82133
83134runtime::Error Module::load (const runtime::Program::Verification verification) {
84135 if (!is_loaded ()) {
136+ // Load the program
85137 if (!data_loader_) {
86- switch (load_mode_) {
87- case LoadMode::File:
88- data_loader_ =
89- ET_UNWRAP_UNIQUE (FileDataLoader::from (file_path_.c_str ()));
90- break ;
91- case LoadMode::Mmap:
92- data_loader_ = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
93- file_path_.c_str (), MmapDataLoader::MlockConfig::NoMlock));
94- break ;
95- case LoadMode::MmapUseMlock:
96- data_loader_ =
97- ET_UNWRAP_UNIQUE (MmapDataLoader::from (file_path_.c_str ()));
98- break ;
99- case LoadMode::MmapUseMlockIgnoreErrors:
100- data_loader_ = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
101- file_path_.c_str (),
102- MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103- break ;
138+ auto res = load_file (file_path_, load_mode_);
139+ if (!res.ok ()) {
140+ return res.error ();
104141 }
105- };
142+ data_loader_ = std::move (res.get ());
143+ }
144+ // If a .ptd path was given load it.
145+ if (data_map_path_ != " " ) {
146+ auto res = load_file (data_map_path_, load_mode_);
147+ if (!res.ok ()) {
148+ return res.error ();
149+ }
150+ data_map_loader_ = std::move (res.get ());
151+ }
152+ // If we have a .ptd loader, then load the map.
153+ if (data_map_loader_) {
154+ data_map_ =
155+ ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loader_.get ()));
156+ }
157+ // else: either the map itself was provided or we have no data map, either
158+ // way no work to do.
106159 auto program = ET_UNWRAP_UNIQUE (
107160 runtime::Program::load (data_loader_.get (), verification));
108161 program_ = std::shared_ptr<runtime::Program>(
@@ -130,6 +183,7 @@ runtime::Error Module::load_method(
130183 ET_CHECK_OK_OR_RETURN_ERROR (load ());
131184
132185 MethodHolder method_holder;
186+
133187 const auto method_metadata =
134188 ET_UNWRAP (program_->method_meta (method_name.c_str ()));
135189 const auto planned_buffersCount =
@@ -155,7 +209,8 @@ runtime::Error Module::load_method(
155209 method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
156210 method_name.c_str (),
157211 method_holder.memory_manager .get (),
158- event_tracer ? event_tracer : this ->event_tracer ()));
212+ event_tracer ? event_tracer : this ->event_tracer (),
213+ data_map_.get ()));
159214 method_holder.inputs .resize (method_holder.method ->inputs_size ());
160215 methods_.emplace (method_name, std::move (method_holder));
161216 }
0 commit comments