@@ -125,11 +125,26 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
125125
126126runtime::Error Module::load_method (
127127 const std::string& method_name,
128- torch::executor::EventTracer* event_tracer) {
128+ torch::executor::EventTracer* event_tracer,
129+ const std::string& data_map_path) {
129130 if (!is_method_loaded (method_name)) {
130131 ET_CHECK_OK_OR_RETURN_ERROR (load ());
131132
132133 MethodHolder method_holder;
134+
135+ // If we have a .ptd load it.
136+ const runtime::NamedDataMap* named_data_map = nullptr ;
137+ if (!data_map_path.empty ()) {
138+ auto data_map_data_loader =
139+ ET_UNWRAP_UNIQUE (FileDataLoader::from (data_map_path.c_str ()));
140+ auto data_map =
141+ ET_UNWRAP_UNIQUE (executorch::extension::FlatTensorDataMap::load (
142+ data_map_data_loader.get ()));
143+ method_holder.data_map_loader = std::move (data_map_data_loader);
144+ method_holder.data_map = std::move (data_map);
145+ }
146+ named_data_map = method_holder.data_map .get ();
147+
133148 const auto method_metadata =
134149 ET_UNWRAP (program_->method_meta (method_name.c_str ()));
135150 const auto planned_buffersCount =
@@ -155,7 +170,8 @@ runtime::Error Module::load_method(
155170 method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
156171 method_name.c_str (),
157172 method_holder.memory_manager .get (),
158- event_tracer ? event_tracer : this ->event_tracer ()));
173+ event_tracer ? event_tracer : this ->event_tracer (),
174+ named_data_map));
159175 method_holder.inputs .resize (method_holder.method ->inputs_size ());
160176 methods_.emplace (method_name, std::move (method_holder));
161177 }
0 commit comments