1010
1111#include < executorch/extension/data_loader/file_data_loader.h>
1212#include < executorch/extension/data_loader/mmap_data_loader.h>
13+ #include < executorch/extension/flat_tensor/flat_tensor_data_map.h>
1314#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
1415#include < executorch/runtime/platform/runtime.h>
1516
3637namespace executorch {
3738namespace extension {
3839
40+ namespace {
41+ runtime::Result<std::unique_ptr<runtime::DataLoader>> load_file (
42+ const std::string& file_path,
43+ Module::LoadMode mode) {
44+ std::unique_ptr<runtime::DataLoader> res = nullptr ;
45+ switch (mode) {
46+ case Module::LoadMode::File:
47+ res =
48+ ET_UNWRAP_UNIQUE (FileDataLoader::from (file_path.c_str ()));
49+ break ;
50+ case Module::LoadMode::Mmap:
51+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
52+ file_path.c_str (), MmapDataLoader::MlockConfig::NoMlock));
53+ break ;
54+ case Module::LoadMode::MmapUseMlock:
55+ res =
56+ ET_UNWRAP_UNIQUE (MmapDataLoader::from (file_path.c_str ()));
57+ break ;
58+ case Module::LoadMode::MmapUseMlockIgnoreErrors:
59+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
60+ file_path.c_str (),
61+ MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
62+ break ;
63+ }
64+ return res;
65+ }
66+ }
67+
3968Module::Module (
4069 const std::string& file_path,
4170 const LoadMode load_mode,
4271 std::unique_ptr<runtime::EventTracer> event_tracer)
4372 : file_path_(file_path),
73+ data_map_path_ (" " ),
4474 load_mode_(load_mode),
4575 memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
4676 temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47- event_tracer_(std::move(event_tracer)) {
77+ event_tracer_(std::move(event_tracer)),
78+ data_map_loader_(nullptr ),
79+ data_map_(nullptr ) {
4880 runtime::runtime_init ();
4981}
5082
83+ Module::Module (
84+ const std::string& file_path,
85+ const std::string& data_map_path,
86+ const LoadMode load_mode,
87+ std::unique_ptr<runtime::EventTracer> event_tracer)
88+ : file_path_(file_path),
89+ data_map_path_(data_map_path),
90+ load_mode_(load_mode),
91+ memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
92+ temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
93+ event_tracer_(std::move(event_tracer)),
94+ data_map_loader_(nullptr ),
95+ data_map_(nullptr ) {
96+ runtime::runtime_init ();
97+ }
98+
99+
100+
51101Module::Module (
52102 std::unique_ptr<runtime::DataLoader> data_loader,
53103 std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
54104 std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55- std::unique_ptr<runtime::EventTracer> event_tracer)
56- : data_loader_(std::move(data_loader)),
105+ std::unique_ptr<runtime::EventTracer> event_tracer,
106+ std::unique_ptr<runtime::NamedDataMap> data_map)
107+ :
108+ file_path_(" " ),
109+ data_map_path_(" " ),
110+ data_loader_(std::move(data_loader)),
57111 memory_allocator_(
58112 memory_allocator ? std::move(memory_allocator)
59113 : std::make_unique<MallocMemoryAllocator>()),
60114 temp_allocator_(
61115 temp_allocator ? std::move(temp_allocator)
62116 : std::make_unique<MallocMemoryAllocator>()),
63- event_tracer_(std::move(event_tracer)) {
117+ event_tracer_(std::move(event_tracer)),
118+ data_map_loader_(nullptr ),
119+ data_map_(std::move(data_map)) {
64120 runtime::runtime_init ();
65121}
66122
67123Module::Module (
68124 std::shared_ptr<runtime::Program> program,
69125 std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70126 std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71- std::unique_ptr<runtime::EventTracer> event_tracer)
72- : program_(std::move(program)),
127+ std::unique_ptr<runtime::EventTracer> event_tracer,
128+ std::unique_ptr<runtime::NamedDataMap> data_map)
129+ :
130+ file_path_(" " ),
131+ data_map_path_(" " ),
132+ program_(std::move(program)),
73133 memory_allocator_(
74134 memory_allocator ? std::move(memory_allocator)
75135 : std::make_unique<MallocMemoryAllocator>()),
76136 temp_allocator_(
77137 temp_allocator ? std::move(temp_allocator)
78138 : std::make_unique<MallocMemoryAllocator>()),
79- event_tracer_(std::move(event_tracer)) {
139+ event_tracer_(std::move(event_tracer)),
140+ data_map_loader_(nullptr ),
141+ data_map_(std::move(data_map)) {
80142 runtime::runtime_init ();
81143}
82144
83145runtime::Error Module::load (const runtime::Program::Verification verification) {
84146 if (!is_loaded ()) {
147+ // Load the program
85148 if (!data_loader_) {
86- switch (load_mode_) {
87- case LoadMode::File:
88- data_loader_ =
89- ET_UNWRAP_UNIQUE (FileDataLoader::from (file_path_.c_str ()));
90- break ;
91- case LoadMode::Mmap:
92- data_loader_ = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
93- file_path_.c_str (), MmapDataLoader::MlockConfig::NoMlock));
94- break ;
95- case LoadMode::MmapUseMlock:
96- data_loader_ =
97- ET_UNWRAP_UNIQUE (MmapDataLoader::from (file_path_.c_str ()));
98- break ;
99- case LoadMode::MmapUseMlockIgnoreErrors:
100- data_loader_ = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
101- file_path_.c_str (),
102- MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103- break ;
149+ auto res = load_file (file_path_, load_mode_);
150+ if (!res.ok ()) {
151+ return res.error ();
104152 }
105- };
153+ data_loader_ = std::move (res.get ());
154+ }
155+ // If a .ptd path was given load it.
156+ if (data_map_path_ != " " ){
157+ auto res = load_file (data_map_path_, load_mode_);
158+ if (!res.ok ()) {
159+ return res.error ();
160+ }
161+ data_map_loader_ = std::move (res.get ());
162+ }
163+ // If we have a .ptd loader, then load the map.
164+ if (data_map_loader_) {
165+ data_map_ = ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loader_.get ()));
166+ }
167+ // else: either the map itself was provided or we have no data map, either way no work to do.
106168 auto program = ET_UNWRAP_UNIQUE (
107169 runtime::Program::load (data_loader_.get (), verification));
108170 program_ = std::shared_ptr<runtime::Program>(
@@ -130,6 +192,7 @@ runtime::Error Module::load_method(
130192 ET_CHECK_OK_OR_RETURN_ERROR (load ());
131193
132194 MethodHolder method_holder;
195+
133196 const auto method_metadata =
134197 ET_UNWRAP (program_->method_meta (method_name.c_str ()));
135198 const auto planned_buffersCount =
@@ -155,10 +218,22 @@ runtime::Error Module::load_method(
155218 method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
156219 method_name.c_str (),
157220 method_holder.memory_manager .get (),
158- event_tracer ? event_tracer : this ->event_tracer ()));
221+ event_tracer ? event_tracer : this ->event_tracer (),
222+ data_map_.get ()));
159223 method_holder.inputs .resize (method_holder.method ->inputs_size ());
160224 methods_.emplace (method_name, std::move (method_holder));
161225 }
226+ return runtime::Error::Ok;
227+ }
228+
229+ runtime::Error Module::load_method (
230+ const std::string& method_name,
231+ const std::string& data_map_path,
232+ torch::executor::EventTracer* event_tracer) {
233+ if (!is_method_loaded (method_name)) {
234+ ET_CHECK_OK_OR_RETURN_ERROR (load ());
235+ return load_method (method_name, event_tracer);
236+ }
162237 return runtime::Error::Ok;
163238}
164239
0 commit comments