/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>
15+ namespace executorch {
16+ namespace extension {
17+ namespace llm {
18+
19+ /* *
20+ * @brief Base class for managing input/output operations for LLM inference.
21+ *
22+ * IOManagerBase provides an interface for handling the input preparation and
23+ * output processing for both prefill and decode phases of LLM inference.
24+ * Derived classes must implement the virtual methods to provide specific IO
25+ * management functionality.
26+ */
27+ class ET_EXPERIMENTAL IOManagerBase {
28+ public:
29+ /* *
30+ * @brief Virtual destructor to allow proper cleanup in derived classes.
31+ */
32+ ET_EXPERIMENTAL virtual ~IOManagerBase () = default ;
33+
34+ /* *
35+ * @brief Initialize the IO manager with method metadata for prefill and
36+ * decode operations.
37+ *
38+ * @param prefill_method The prefill method to initialize with.
39+ * @param decode_method The decode method to initialize with.
40+ */
41+ ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error init (
42+ executorch::runtime::Method& prefill_method,
43+ executorch::runtime::Method& decode_method) = 0;
44+
45+ /* *
46+ * @brief Reset the IO manager state.
47+ *
48+ * @param prefill_method The prefill method to reset with.
49+ * @param decode_method The decode method to reset with.
50+ */
51+ ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error reset (
52+ executorch::runtime::Method& prefill_method,
53+ executorch::runtime::Method& decode_method) = 0;
54+
55+ /* *
56+ * @brief Prepare inputs for the prefill phase of LLM inference.
57+ *
58+ * @param input The input tensor containing token IDs.
59+ * @param start_pos The tensor containing the starting position of the current
60+ * input within the context.
61+ * @param prefill_method The prefill method to prepare inputs for.
62+ * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
63+ * for the prefill method.
64+ */
65+ ET_EXPERIMENTAL virtual runtime::Result<
66+ std::vector<executorch::runtime::EValue>>
67+ prepare_prefill (
68+ const executorch::extension::TensorPtr& input,
69+ const executorch::extension::TensorPtr& start_pos,
70+ executorch::runtime::Method& prefill_method) = 0 ;
71+
72+ /* *
73+ * @brief Prepare inputs for the decode phase of LLM inference.
74+ *
75+ * @param input The input tensor containing token IDs.
76+ * @param start_pos The tensor containing the starting position of the current
77+ * input within the context.
78+ * @param decode_method The decode method to prepare inputs for.
79+ * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
80+ * for the decode method.
81+ */
82+ ET_EXPERIMENTAL virtual runtime::Result<
83+ std::vector<executorch::runtime::EValue>>
84+ prepare_decode (
85+ const executorch::extension::TensorPtr& input,
86+ const executorch::extension::TensorPtr& start_pos,
87+ executorch::runtime::Method& decode_method) = 0 ;
88+
89+ /* *
90+ * @brief Process and update internal state with outputs from the prefill
91+ * phase.
92+ *
93+ * @param prefill_method The prefill method to update with outputs.
94+ * @param model_outputs Vector of outputs from the prefill method execution.
95+ */
96+ ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error update_prefill (
97+ executorch::runtime::Method& prefill_method,
98+ const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
99+
100+ /* *
101+ * @brief Process and update internal state with outputs from the decode
102+ * phase.
103+ *
104+ * @param decode_method The decode method to update with outputs.
105+ * @param model_outputs Vector of outputs from the decode method execution.
106+ */
107+ ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error update_decode (
108+ const executorch::runtime::Method& decode_method,
109+ const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
110+ };
111+
112+ } // namespace llm
113+ } // namespace extension
114+ } // namespace executorch