#pragma once

#include <string>
#include <vector>

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
1814namespace executorch {
1915namespace extension {
@@ -29,6 +25,13 @@ namespace llm {
2925 */
3026class ET_EXPERIMENTAL IOManager {
3127 public:
28+ /* *
29+ * @brief Construct an IOManager bound to a Module.
30+ *
31+ * @param module The Module used for querying method metadata and execution.
32+ */
33+ explicit IOManager (ET_MODULE_NAMESPACE::Module& module ) : module_(module ) {}
34+
3235 /* *
3336 * @brief Virtual destructor to allow proper cleanup in derived classes.
3437 */
@@ -38,88 +41,143 @@ class ET_EXPERIMENTAL IOManager {
3841 * @brief Load the IO manager with method metadata for prefill and
3942 * decode operations.
4043 *
41- * @param program The program prefill and decode methods are loaded from.
4244 * @param prefill_method The prefill method to initialize with.
4345 * @param decode_method The decode method to initialize with.
4446 */
4547 ET_NODISCARD virtual runtime::Error load (
46- const executorch::ET_RUNTIME_NAMESPACE::Program& program,
47- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
48- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
49- (void )program;
48+ const std::string& prefill_method,
49+ const std::string& decode_method) {
5050 (void )prefill_method;
5151 (void )decode_method;
5252 return runtime::Error::Ok;
5353 }
5454
55+ /* *
56+ * @brief Load the IO manager using the default method names.
57+ *
58+ * Uses "forward" for both prefill and decode.
59+ *
60+ * @return Error code.
61+ */
62+ ET_NODISCARD runtime::Error load () {
63+ return load (" forward" , " forward" );
64+ }
65+
5566 /* *
5667 * @brief Reset the IO manager state.
5768 *
5869 * @param prefill_method The prefill method to reset with.
5970 * @param decode_method The decode method to reset with.
6071 */
6172 ET_NODISCARD virtual runtime::Error reset (
62- executorch::ET_RUNTIME_NAMESPACE::Method & prefill_method,
63- executorch::ET_RUNTIME_NAMESPACE::Method & decode_method) {
73+ const std::string & prefill_method,
74+ const std::string & decode_method) {
6475 (void )prefill_method;
6576 (void )decode_method;
6677 return runtime::Error::Ok;
6778 }
6879
80+ /* *
81+ * @brief Reset the IO manager state using the default method names.
82+ *
83+ * Uses "forward" for both prefill and decode.
84+ *
85+ * @return Error code.
86+ */
87+ ET_NODISCARD runtime::Error reset () {
88+ return reset (" forward" , " forward" );
89+ }
90+
6991 /* *
7092 * @brief Prepare inputs for the prefill phase of LLM inference.
7193 *
7294 * @param input The input tensor containing token IDs.
7395 * @param start_pos The tensor containing the starting position of the current
7496 * input within the context.
7597 * @param prefill_method The prefill method to prepare inputs for.
76- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
98+ * @return std::vector<runtime::EValue> Vector of prepared inputs
7799 * for the prefill method.
78100 */
79- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
80- prepare_prefill (
81- const executorch::extension::TensorPtr& input,
82- const executorch::extension::TensorPtr& start_pos,
83- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
84- if (prefill_method.inputs_size () != 2 ) {
101+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
102+ const TensorPtr& input,
103+ const TensorPtr& start_pos,
104+ const std::string& prefill_method) {
105+ auto method_meta = module_.method_meta (prefill_method);
106+ if (!method_meta.ok ()) {
107+ return method_meta.error ();
108+ }
109+ if (method_meta->num_inputs () != 2 ) {
85110 ET_LOG (
86111 Error,
87112 " Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
88- prefill_method. inputs_size ());
113+ method_meta-> num_inputs ());
89114 return runtime::Error::InvalidState;
90115 }
91116 // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
92117 // here.
93118 return std::vector<runtime::EValue>{input, start_pos};
94119 }
95120
121+ /* *
122+ * @brief Prepare inputs for the prefill phase using the default method name.
123+ *
124+ * Uses "forward" as the prefill method.
125+ *
126+ * @param input The input tensor containing token IDs.
127+ * @param start_pos The tensor containing the starting position.
128+ * @return Vector of prepared inputs for the prefill method.
129+ */
130+ runtime::Result<std::vector<runtime::EValue>> prepare_prefill (
131+ const TensorPtr& input,
132+ const TensorPtr& start_pos) {
133+ return prepare_prefill (input, start_pos, " forward" );
134+ }
135+
96136 /* *
97137 * @brief Prepare inputs for the decode phase of LLM inference.
98138 *
99139 * @param input The input tensor containing token IDs.
100140 * @param start_pos The tensor containing the starting position of the current
101141 * input within the context.
102142 * @param decode_method The decode method to prepare inputs for.
103- * @return std::vector<executorch:: runtime::EValue> Vector of prepared inputs
143+ * @return std::vector<runtime::EValue> Vector of prepared inputs
104144 * for the decode method.
105145 */
106- virtual runtime::Result<std::vector<executorch::runtime::EValue>>
107- prepare_decode (
108- const executorch::extension::TensorPtr& input,
109- const executorch::extension::TensorPtr& start_pos,
110- executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
111- if (decode_method.inputs_size () != 2 ) {
146+ virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode (
147+ const TensorPtr& input,
148+ const TensorPtr& start_pos,
149+ const std::string& decode_method) {
150+ auto method_meta = module_.method_meta (decode_method);
151+ if (!method_meta.ok ()) {
152+ return method_meta.error ();
153+ }
154+ if (method_meta->num_inputs () != 2 ) {
112155 ET_LOG (
113156 Error,
114157 " Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support." ,
115- decode_method. inputs_size ());
158+ method_meta-> num_inputs ());
116159 return runtime::Error::InvalidState;
117160 }
118161 // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
119162 // here.
120163 return std::vector<runtime::EValue>{input, start_pos};
121164 }
122165
166+ /* *
167+ * @brief Prepare inputs for the decode phase using the default method name.
168+ *
169+ * Uses "forward" as the decode method.
170+ *
171+ * @param input The input tensor containing token IDs.
172+ * @param start_pos The tensor containing the starting position.
173+ * @return Vector of prepared inputs for the decode method.
174+ */
175+ runtime::Result<std::vector<runtime::EValue>> prepare_decode (
176+ const TensorPtr& input,
177+ const TensorPtr& start_pos) {
178+ return prepare_decode (input, start_pos, " forward" );
179+ }
180+
123181 /* *
124182 * @brief Process and update internal state with outputs from the prefill
125183 * phase.
@@ -128,14 +186,27 @@ class ET_EXPERIMENTAL IOManager {
128186 * @param model_outputs Vector of outputs from the prefill method execution.
129187 */
130188 ET_NODISCARD virtual runtime::Error update_prefill (
131- executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
132- const std::vector<executorch::runtime::EValue>& model_outputs) {
133- (void )prefill_method;
189+ const std::vector<runtime::EValue>& model_outputs,
190+ const std::string& prefill_method) {
134191 (void )model_outputs;
192+ (void )prefill_method;
135193 // No post inference work to do.
136194 return runtime::Error::Ok;
137195 }
138196
197+ /* *
198+ * @brief Process outputs from the prefill phase using the default method.
199+ *
200+ * Uses "forward" as the prefill method.
201+ *
202+ * @param model_outputs Vector of outputs from the prefill execution.
203+ * @return Error code.
204+ */
205+ ET_NODISCARD runtime::Error update_prefill (
206+ const std::vector<runtime::EValue>& model_outputs) {
207+ return update_prefill (model_outputs, " forward" );
208+ }
209+
139210 /* *
140211 * @brief Process and update internal state with outputs from the decode
141212 * phase.
@@ -144,13 +215,32 @@ class ET_EXPERIMENTAL IOManager {
144215 * @param model_outputs Vector of outputs from the decode method execution.
145216 */
146217 ET_NODISCARD virtual runtime::Error update_decode (
147- const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
148- const std::vector<executorch::runtime::EValue>& model_outputs) {
149- (void )decode_method;
218+ const std::vector<runtime::EValue>& model_outputs,
219+ const std::string& decode_method) {
150220 (void )model_outputs;
221+ (void )decode_method;
151222 // No post inference work to do.
152223 return runtime::Error::Ok;
153224 }
225+
226+ /* *
227+ * @brief Process outputs from the decode phase using the default method.
228+ *
229+ * Uses "forward" as the decode method.
230+ *
231+ * @param model_outputs Vector of outputs from the decode execution.
232+ * @return Error code.
233+ */
234+ ET_NODISCARD runtime::Error update_decode (
235+ const std::vector<runtime::EValue>& model_outputs) {
236+ return update_decode (model_outputs, " forward" );
237+ }
238+
239+ private:
240+ /* *
241+ * @brief Reference to the Module used for method metadata and execution.
242+ */
243+ ET_MODULE_NAMESPACE::Module& module_;
154244};
155245
156246} // namespace llm
// NOTE(review): removed trailing web-page residue ("0 commit comments") that
// was pasted in with the diff; the remaining namespace closers follow in the
// full file.