1414#include < unordered_set>
1515#include < vector>
1616
17+ #include < executorch/extension/flat_tensor/flat_tensor_data_map.h>
1718#include < executorch/runtime/executor/program.h>
1819
1920namespace executorch {
@@ -133,6 +134,8 @@ class Module {
133134 * needed. The loaded method is cached to reuse the next time it's executed.
134135 *
135136 * @param[in] method_name The name of the method to load.
137+ * @param[in] data_map_path Path to a .ptd file containing weights
138+ * for this method.
136139 * @param[in] event_tracer Per-method event tracer to profile/trace methods
137140 * individually. When not given, the event tracer passed to the Module
138141 * constructor is used. Otherwise, this per-method event tracer takes
@@ -143,8 +146,29 @@ class Module {
143146 ET_NODISCARD
144147 runtime::Error load_method (
145148 const std::string& method_name,
149+ const std::string& data_map_path,
146150 torch::executor::EventTracer* event_tracer = nullptr );
147151
152+ /* *
153+ * Load a specific method from the program and set up memory management if
154+ * needed. The loaded method is cached to reuse the next time it's executed.
155+ *
156+ * @param[in] method_name The name of the method to load.
157+ * @param[in] event_tracer Per-method event tracer to profile/trace methods
158+ * individually. When not given, the event tracer passed to the Module
159+ * constructor is used. Otherwise, this per-method event tracer takes
160+ * precedence.
161+ * @param[in] data_map_data_loader Optional data loader for the .ptd file
162+ * for this method.
163+ *
164+ * @returns An Error to indicate success or failure.
165+ */
166+ ET_NODISCARD
167+ runtime::Error load_method (
168+ const std::string& method_name,
169+ torch::executor::EventTracer* event_tracer = nullptr ,
170+ std::unique_ptr<runtime::DataLoader> data_map_data_loader = nullptr );
171+
148172 /* *
149173 * Load the 'forward' method from the program and set up memory management if
150174 * needed. The loaded method is cached to reuse the next time it's executed.
@@ -155,8 +179,9 @@ class Module {
155179 * @returns An Error to indicate success or failure.
156180 */
157181 ET_NODISCARD inline runtime::Error load_forward (
158- torch::executor::EventTracer* event_tracer = nullptr ) {
159- return load_method (" forward" , event_tracer);
182+ torch::executor::EventTracer* event_tracer = nullptr ,
183+ std::unique_ptr<runtime::DataLoader> data_map_data_loader = nullptr ) {
184+ return load_method (" forward" , event_tracer, std::move (data_map_data_loader));
160185 }
161186
162187 /* *
@@ -430,10 +455,11 @@ class Module {
430455 std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
431456 std::unique_ptr<runtime::MemoryManager> memory_manager;
432457 std::unique_ptr<runtime::Method> method;
458+ std::unique_ptr<runtime::DataLoader> data_map_loader;
459+ std::unique_ptr<extension::FlatTensorDataMap> data_map;
433460 std::vector<runtime::EValue> inputs;
434461 };
435462
436- private:
437463 std::string file_path_;
438464 LoadMode load_mode_{LoadMode::MmapUseMlock};
439465 std::shared_ptr<runtime::Program> program_;
0 commit comments