@@ -125,11 +125,25 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
125125
126126runtime::Error Module::load_method (
127127 const std::string& method_name,
128- torch::executor::EventTracer* event_tracer) {
128+ torch::executor::EventTracer* event_tracer,
129+ const std::string& data_map_path) {
129130 if (!is_method_loaded (method_name)) {
130131 ET_CHECK_OK_OR_RETURN_ERROR (load ());
131132
132133 MethodHolder method_holder;
134+
135+ // If we have a .ptd load it.
136+ const runtime::NamedDataMap* named_data_map = nullptr ;
137+ if (!data_map_path.empty ()) {
138+ auto data_map_data_loader =
139+ ET_UNWRAP_UNIQUE (FileDataLoader::from (data_map_path.c_str ()));
140+ auto data_map = ET_UNWRAP_UNIQUE (
141+ executorch::extension::FlatTensorDataMap::load (data_map_data_loader.get ()));
142+ method_holder.data_map_loader = std::move (data_map_data_loader);
143+ method_holder.data_map = std::move (data_map);
144+ }
145+ named_data_map = method_holder.data_map .get ();
146+
133147 const auto method_metadata =
134148 ET_UNWRAP (program_->method_meta (method_name.c_str ()));
135149 const auto planned_buffersCount =
@@ -155,7 +169,8 @@ runtime::Error Module::load_method(
155169 method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
156170 method_name.c_str (),
157171 method_holder.memory_manager .get (),
158- event_tracer ? event_tracer : this ->event_tracer ()));
172+ event_tracer ? event_tracer : this ->event_tracer (),
173+ named_data_map));
159174 method_holder.inputs .resize (method_holder.method ->inputs_size ());
160175 methods_.emplace (method_name, std::move (method_holder));
161176 }
0 commit comments