Skip to content

Commit 2603f74

Browse files
authored
Make IOManager use Module instead of Method. (#13542)
Summary: Let's not expose Method from Module so that it doesn't get misused beyond its owner. Differential Revision: D80595261
1 parent 3db27cd commit 2603f74

File tree

10 files changed

+190
-180
lines changed

10 files changed

+190
-180
lines changed

examples/models/llava/runner/llava_runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner {
4242
const float temperature = 0.8f)
4343
: temperature_(temperature),
4444
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
45-
io_manager_(std::make_unique<IOManager>()),
45+
io_manager_(std::make_unique<IOManager>(*module_)),
4646
tokenizer_path_(tokenizer_path) {
4747
ET_LOG(
4848
Info,

extension/llm/runner/io_manager/io_manager.h

Lines changed: 124 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,8 @@
88

99
#pragma once
1010

11-
#include <vector>
12-
11+
#include <executorch/extension/module/module.h>
1312
#include <executorch/extension/tensor/tensor.h>
14-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
15-
#include <executorch/runtime/executor/method.h>
16-
#include <executorch/runtime/executor/method_meta.h>
1713

1814
namespace executorch {
1915
namespace extension {
@@ -29,6 +25,13 @@ namespace llm {
2925
*/
3026
class ET_EXPERIMENTAL IOManager {
3127
public:
28+
/**
29+
* @brief Construct an IOManager bound to a Module.
30+
*
31+
* @param module The Module used for querying method metadata and execution.
32+
*/
33+
explicit IOManager(ET_MODULE_NAMESPACE::Module& module) : module_(module) {}
34+
3235
/**
3336
* @brief Virtual destructor to allow proper cleanup in derived classes.
3437
*/
@@ -38,88 +41,143 @@ class ET_EXPERIMENTAL IOManager {
3841
* @brief Load the IO manager with method metadata for prefill and
3942
* decode operations.
4043
*
41-
* @param program The program prefill and decode methods are loaded from.
4244
* @param prefill_method The prefill method to initialize with.
4345
* @param decode_method The decode method to initialize with.
4446
*/
4547
ET_NODISCARD virtual runtime::Error load(
46-
const executorch::ET_RUNTIME_NAMESPACE::Program& program,
47-
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
48-
executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
49-
(void)program;
48+
const std::string& prefill_method,
49+
const std::string& decode_method) {
5050
(void)prefill_method;
5151
(void)decode_method;
5252
return runtime::Error::Ok;
5353
}
5454

55+
/**
56+
* @brief Load the IO manager using the default method names.
57+
*
58+
* Uses "forward" for both prefill and decode.
59+
*
60+
* @return Error code.
61+
*/
62+
ET_NODISCARD runtime::Error load() {
63+
return load("forward", "forward");
64+
}
65+
5566
/**
5667
* @brief Reset the IO manager state.
5768
*
5869
* @param prefill_method The prefill method to reset with.
5970
* @param decode_method The decode method to reset with.
6071
*/
6172
ET_NODISCARD virtual runtime::Error reset(
62-
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
63-
executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
73+
const std::string& prefill_method,
74+
const std::string& decode_method) {
6475
(void)prefill_method;
6576
(void)decode_method;
6677
return runtime::Error::Ok;
6778
}
6879

80+
/**
81+
* @brief Reset the IO manager state using the default method names.
82+
*
83+
* Uses "forward" for both prefill and decode.
84+
*
85+
* @return Error code.
86+
*/
87+
ET_NODISCARD runtime::Error reset() {
88+
return reset("forward", "forward");
89+
}
90+
6991
/**
7092
* @brief Prepare inputs for the prefill phase of LLM inference.
7193
*
7294
* @param input The input tensor containing token IDs.
7395
* @param start_pos The tensor containing the starting position of the current
7496
* input within the context.
7597
* @param prefill_method The prefill method to prepare inputs for.
76-
* @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
98+
* @return std::vector<runtime::EValue> Vector of prepared inputs
7799
* for the prefill method.
78100
*/
79-
virtual runtime::Result<std::vector<executorch::runtime::EValue>>
80-
prepare_prefill(
81-
const executorch::extension::TensorPtr& input,
82-
const executorch::extension::TensorPtr& start_pos,
83-
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
84-
if (prefill_method.inputs_size() != 2) {
101+
virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
102+
const TensorPtr& input,
103+
const TensorPtr& start_pos,
104+
const std::string& prefill_method) {
105+
auto method_meta = module_.method_meta(prefill_method);
106+
if (!method_meta.ok()) {
107+
return method_meta.error();
108+
}
109+
if (method_meta->num_inputs() != 2) {
85110
ET_LOG(
86111
Error,
87112
"Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
88-
prefill_method.inputs_size());
113+
method_meta->num_inputs());
89114
return runtime::Error::InvalidState;
90115
}
91116
// Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
92117
// here.
93118
return std::vector<runtime::EValue>{input, start_pos};
94119
}
95120

121+
/**
122+
* @brief Prepare inputs for the prefill phase using the default method name.
123+
*
124+
* Uses "forward" as the prefill method.
125+
*
126+
* @param input The input tensor containing token IDs.
127+
* @param start_pos The tensor containing the starting position.
128+
* @return Vector of prepared inputs for the prefill method.
129+
*/
130+
runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
131+
const TensorPtr& input,
132+
const TensorPtr& start_pos) {
133+
return prepare_prefill(input, start_pos, "forward");
134+
}
135+
96136
/**
97137
* @brief Prepare inputs for the decode phase of LLM inference.
98138
*
99139
* @param input The input tensor containing token IDs.
100140
* @param start_pos The tensor containing the starting position of the current
101141
* input within the context.
102142
* @param decode_method The decode method to prepare inputs for.
103-
* @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
143+
* @return std::vector<runtime::EValue> Vector of prepared inputs
104144
* for the decode method.
105145
*/
106-
virtual runtime::Result<std::vector<executorch::runtime::EValue>>
107-
prepare_decode(
108-
const executorch::extension::TensorPtr& input,
109-
const executorch::extension::TensorPtr& start_pos,
110-
executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
111-
if (decode_method.inputs_size() != 2) {
146+
virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode(
147+
const TensorPtr& input,
148+
const TensorPtr& start_pos,
149+
const std::string& decode_method) {
150+
auto method_meta = module_.method_meta(decode_method);
151+
if (!method_meta.ok()) {
152+
return method_meta.error();
153+
}
154+
if (method_meta->num_inputs() != 2) {
112155
ET_LOG(
113156
Error,
114157
"Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
115-
decode_method.inputs_size());
158+
method_meta->num_inputs());
116159
return runtime::Error::InvalidState;
117160
}
118161
// Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
119162
// here.
120163
return std::vector<runtime::EValue>{input, start_pos};
121164
}
122165

166+
/**
167+
* @brief Prepare inputs for the decode phase using the default method name.
168+
*
169+
* Uses "forward" as the decode method.
170+
*
171+
* @param input The input tensor containing token IDs.
172+
* @param start_pos The tensor containing the starting position.
173+
* @return Vector of prepared inputs for the decode method.
174+
*/
175+
runtime::Result<std::vector<runtime::EValue>> prepare_decode(
176+
const TensorPtr& input,
177+
const TensorPtr& start_pos) {
178+
return prepare_decode(input, start_pos, "forward");
179+
}
180+
123181
/**
124182
* @brief Process and update internal state with outputs from the prefill
125183
* phase.
@@ -128,14 +186,27 @@ class ET_EXPERIMENTAL IOManager {
128186
* @param model_outputs Vector of outputs from the prefill method execution.
129187
*/
130188
ET_NODISCARD virtual runtime::Error update_prefill(
131-
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
132-
const std::vector<executorch::runtime::EValue>& model_outputs) {
133-
(void)prefill_method;
189+
const std::vector<runtime::EValue>& model_outputs,
190+
const std::string& prefill_method) {
134191
(void)model_outputs;
192+
(void)prefill_method;
135193
// No post inference work to do.
136194
return runtime::Error::Ok;
137195
}
138196

197+
/**
198+
* @brief Process outputs from the prefill phase using the default method.
199+
*
200+
* Uses "forward" as the prefill method.
201+
*
202+
* @param model_outputs Vector of outputs from the prefill execution.
203+
* @return Error code.
204+
*/
205+
ET_NODISCARD runtime::Error update_prefill(
206+
const std::vector<runtime::EValue>& model_outputs) {
207+
return update_prefill(model_outputs, "forward");
208+
}
209+
139210
/**
140211
* @brief Process and update internal state with outputs from the decode
141212
* phase.
@@ -144,13 +215,32 @@ class ET_EXPERIMENTAL IOManager {
144215
* @param model_outputs Vector of outputs from the decode method execution.
145216
*/
146217
ET_NODISCARD virtual runtime::Error update_decode(
147-
const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
148-
const std::vector<executorch::runtime::EValue>& model_outputs) {
149-
(void)decode_method;
218+
const std::vector<runtime::EValue>& model_outputs,
219+
const std::string& decode_method) {
150220
(void)model_outputs;
221+
(void)decode_method;
151222
// No post inference work to do.
152223
return runtime::Error::Ok;
153224
}
225+
226+
/**
227+
* @brief Process outputs from the decode phase using the default method.
228+
*
229+
* Uses "forward" as the decode method.
230+
*
231+
* @param model_outputs Vector of outputs from the decode execution.
232+
* @return Error code.
233+
*/
234+
ET_NODISCARD runtime::Error update_decode(
235+
const std::vector<runtime::EValue>& model_outputs) {
236+
return update_decode(model_outputs, "forward");
237+
}
238+
239+
private:
240+
/**
241+
* @brief Reference to the Module used for method metadata and execution.
242+
*/
243+
ET_MODULE_NAMESPACE::Module& module_;
154244
};
155245

156246
} // namespace llm

extension/llm/runner/io_manager/targets.bzl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ def define_common_targets():
1111
exported_headers = [
1212
"io_manager.h",
1313
],
14-
deps = [
14+
exported_deps = [
1515
"//executorch/extension/tensor:tensor" + aten_suffix,
16-
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
17-
"//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,
16+
"//executorch/extension/module:module" + aten_suffix,
1817
],
1918
visibility = [
2019
"@EXECUTORCH_CLIENTS",

extension/llm/runner/io_manager/test/TARGETS

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,12 @@ define_common_targets()
1010

1111
runtime.cxx_test(
1212
name = "test_io_manager",
13-
srcs = ["test_io_manager.cpp"],
13+
srcs = [
14+
"test_io_manager.cpp",
15+
],
1416
deps = [
1517
"//executorch/extension/llm/runner/io_manager:io_manager",
16-
"//executorch/extension/llm/runner/io_manager:io_manager",
17-
"//executorch/extension/module:module",
18-
"//executorch/extension/tensor:tensor",
19-
"//executorch/runtime/executor:program",
20-
"//executorch/kernels/portable:generated_lib",
18+
"//executorch/kernels/portable:generated_lib",
2119
],
2220
env = {
2321
"KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",

0 commit comments

Comments (0)