8
8
9
9
#pragma once
10
10
11
- #include < vector>
12
-
11
+ #include < executorch/extension/module/module.h>
13
12
#include < executorch/extension/tensor/tensor.h>
14
- #include < executorch/runtime/core/exec_aten/exec_aten.h>
15
- #include < executorch/runtime/executor/method.h>
16
- #include < executorch/runtime/executor/method_meta.h>
17
13
18
14
namespace executorch {
19
15
namespace extension {
@@ -29,6 +25,13 @@ namespace llm {
29
25
*/
30
26
class ET_EXPERIMENTAL IOManager {
31
27
public:
28
+ /* *
29
+ * @brief Construct an IOManager bound to a Module.
30
+ *
31
+ * @param module The Module used for querying method metadata and execution.
32
+ */
33
+ explicit IOManager (ET_MODULE_NAMESPACE::Module& module ) : module_(module ) {}
34
+
32
35
/* *
33
36
* @brief Virtual destructor to allow proper cleanup in derived classes.
34
37
*/
@@ -38,88 +41,143 @@ class ET_EXPERIMENTAL IOManager {
38
41
* @brief Load the IO manager with method metadata for prefill and
39
42
* decode operations.
40
43
*
41
- * @param program The program prefill and decode methods are loaded from.
42
44
* @param prefill_method The prefill method to initialize with.
43
45
* @param decode_method The decode method to initialize with.
44
46
*/
45
47
ET_NODISCARD virtual runtime::Error load (
46
- const executorch::ET_RUNTIME_NAMESPACE::Program& program,
47
- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
48
- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
49
- (void )program;
48
+ const std::string& prefill_method,
49
+ const std::string& decode_method) {
50
50
(void )prefill_method;
51
51
(void )decode_method;
52
52
return runtime::Error::Ok;
53
53
}
54
54
55
+ /* *
56
+ * @brief Load the IO manager using the default method names.
57
+ *
58
+ * Uses "forward" for both prefill and decode.
59
+ *
60
+ * @return Error code.
61
+ */
62
+ ET_NODISCARD runtime::Error load () {
63
+ return load (" forward" , " forward" );
64
+ }
65
+
55
66
/* *
56
67
* @brief Reset the IO manager state.
57
68
*
58
69
* @param prefill_method The prefill method to reset with.
59
70
* @param decode_method The decode method to reset with.
60
71
*/
61
72
ET_NODISCARD virtual runtime::Error reset (
62
- executorch::ET_RUNTIME_NAMESPACE::Method & prefill_method,
63
- executorch::ET_RUNTIME_NAMESPACE::Method & decode_method) {
73
+ const std::string & prefill_method,
74
+ const std::string & decode_method) {
64
75
(void )prefill_method;
65
76
(void )decode_method;
66
77
return runtime::Error::Ok;
67
78
}
68
79
80
+ /* *
81
+ * @brief Reset the IO manager state using the default method names.
82
+ *
83
+ * Uses "forward" for both prefill and decode.
84
+ *
85
+ * @return Error code.
86
+ */
87
+ ET_NODISCARD runtime::Error reset () {
88
+ return reset (" forward" , " forward" );
89
+ }
90
+
69
91
/* *
70
92
* @brief Prepare inputs for the prefill phase of LLM inference.
71
93
*
72
94
* @param input The input tensor containing token IDs.
73
95
* @param start_pos The tensor containing the starting position of the current
74
96
* input within the context.
75
97
* @param prefill_method The prefill method to prepare inputs for.
76
- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
98
+ * @return std::vector<runtime::EValue> Vector of prepared inputs
77
99
* for the prefill method.
78
100
*/
79
- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
80
- prepare_prefill (
81
- const executorch::extension::TensorPtr& input,
82
- const executorch::extension::TensorPtr& start_pos,
83
- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
84
- if (prefill_method.inputs_size () != 2 ) {
101
+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
102
+ const TensorPtr& input,
103
+ const TensorPtr& start_pos,
104
+ const std::string& prefill_method) {
105
+ auto method_meta = module_.method_meta (prefill_method);
106
+ if (!method_meta.ok ()) {
107
+ return method_meta.error ();
108
+ }
109
+ if (method_meta->num_inputs () != 2 ) {
85
110
ET_LOG (
86
111
Error,
87
112
" Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
88
- prefill_method. inputs_size ());
113
+ method_meta-> num_inputs ());
89
114
return runtime::Error::InvalidState;
90
115
}
91
116
// Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
92
117
// here.
93
118
return std::vector<runtime::EValue>{input, start_pos};
94
119
}
95
120
121
+ /* *
122
+ * @brief Prepare inputs for the prefill phase using the default method name.
123
+ *
124
+ * Uses "forward" as the prefill method.
125
+ *
126
+ * @param input The input tensor containing token IDs.
127
+ * @param start_pos The tensor containing the starting position.
128
+ * @return Vector of prepared inputs for the prefill method.
129
+ */
130
+ runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
131
+ const TensorPtr& input,
132
+ const TensorPtr& start_pos) {
133
+ return prepare_prefill (input, start_pos, " forward" );
134
+ }
135
+
96
136
/* *
97
137
* @brief Prepare inputs for the decode phase of LLM inference.
98
138
*
99
139
* @param input The input tensor containing token IDs.
100
140
* @param start_pos The tensor containing the starting position of the current
101
141
* input within the context.
102
142
* @param decode_method The decode method to prepare inputs for.
103
- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
143
+ * @return std::vector<runtime::EValue> Vector of prepared inputs
104
144
* for the decode method.
105
145
*/
106
- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
107
- prepare_decode (
108
- const executorch::extension::TensorPtr& input,
109
- const executorch::extension::TensorPtr& start_pos,
110
- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
111
- if (decode_method.inputs_size () != 2 ) {
146
+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode (
147
+ const TensorPtr& input,
148
+ const TensorPtr& start_pos,
149
+ const std::string& decode_method) {
150
+ auto method_meta = module_.method_meta (decode_method);
151
+ if (!method_meta.ok ()) {
152
+ return method_meta.error ();
153
+ }
154
+ if (method_meta->num_inputs () != 2 ) {
112
155
ET_LOG (
113
156
Error,
114
157
" Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
115
- decode_method. inputs_size ());
158
+ method_meta-> num_inputs ());
116
159
return runtime::Error::InvalidState;
117
160
}
118
161
// Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
119
162
// here.
120
163
return std::vector<runtime::EValue>{input, start_pos};
121
164
}
122
165
166
+ /* *
167
+ * @brief Prepare inputs for the decode phase using the default method name.
168
+ *
169
+ * Uses "forward" as the decode method.
170
+ *
171
+ * @param input The input tensor containing token IDs.
172
+ * @param start_pos The tensor containing the starting position.
173
+ * @return Vector of prepared inputs for the decode method.
174
+ */
175
+ runtime::Result<std::vector<runtime::EValue>> prepare_decode (
176
+ const TensorPtr& input,
177
+ const TensorPtr& start_pos) {
178
+ return prepare_decode (input, start_pos, " forward" );
179
+ }
180
+
123
181
/* *
124
182
* @brief Process and update internal state with outputs from the prefill
125
183
* phase.
@@ -128,14 +186,27 @@ class ET_EXPERIMENTAL IOManager {
128
186
* @param model_outputs Vector of outputs from the prefill method execution.
129
187
*/
130
188
ET_NODISCARD virtual runtime::Error update_prefill (
131
- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
132
- const std::vector<executorch::runtime::EValue>& model_outputs) {
133
- (void )prefill_method;
189
+ const std::vector<runtime::EValue>& model_outputs,
190
+ const std::string& prefill_method) {
134
191
(void )model_outputs;
192
+ (void )prefill_method;
135
193
// No post inference work to do.
136
194
return runtime::Error::Ok;
137
195
}
138
196
197
+ /* *
198
+ * @brief Process outputs from the prefill phase using the default method.
199
+ *
200
+ * Uses "forward" as the prefill method.
201
+ *
202
+ * @param model_outputs Vector of outputs from the prefill execution.
203
+ * @return Error code.
204
+ */
205
+ ET_NODISCARD runtime::Error update_prefill (
206
+ const std::vector<runtime::EValue>& model_outputs) {
207
+ return update_prefill (model_outputs, " forward" );
208
+ }
209
+
139
210
/* *
140
211
* @brief Process and update internal state with outputs from the decode
141
212
* phase.
@@ -144,13 +215,32 @@ class ET_EXPERIMENTAL IOManager {
144
215
* @param model_outputs Vector of outputs from the decode method execution.
145
216
*/
146
217
ET_NODISCARD virtual runtime::Error update_decode (
147
- const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
148
- const std::vector<executorch::runtime::EValue>& model_outputs) {
149
- (void )decode_method;
218
+ const std::vector<runtime::EValue>& model_outputs,
219
+ const std::string& decode_method) {
150
220
(void )model_outputs;
221
+ (void )decode_method;
151
222
// No post inference work to do.
152
223
return runtime::Error::Ok;
153
224
}
225
+
226
+ /* *
227
+ * @brief Process outputs from the decode phase using the default method.
228
+ *
229
+ * Uses "forward" as the decode method.
230
+ *
231
+ * @param model_outputs Vector of outputs from the decode execution.
232
+ * @return Error code.
233
+ */
234
+ ET_NODISCARD runtime::Error update_decode (
235
+ const std::vector<runtime::EValue>& model_outputs) {
236
+ return update_decode (model_outputs, " forward" );
237
+ }
238
+
239
+ private:
240
+ /* *
241
+ * @brief Reference to the Module used for method metadata and execution.
242
+ */
243
+ ET_MODULE_NAMESPACE::Module& module_;
154
244
};
155
245
156
246
} // namespace llm
0 commit comments