Skip to content

Commit 4fcd995

Browse files
JacobSzwejbka authored and facebook-github-bot committed
IOManager Interface (#10418)
Summary: Pull Request resolved: #10418. Hopefully this is sufficient for the contract. Going to do two follow-up tests: add a basic CPU implementation, and add a static attention implementation. Differential Revision: D73450877
1 parent 76835e8 commit 4fcd995

File tree

3 files changed

+148
-0
lines changed

3 files changed

+148
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
# NOTE(review): presumably records "executorch" as the owning oncall for the
# targets in this file -- confirm against the build system's `oncall` docs.
oncall("executorch")
7+
8+
# Invoke the macro loaded from targets.bzl above; per the header comment, that
# is where targets shared between fbcode and xplat are defined.
define_common_targets()
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/tensor/tensor.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/executor/method.h>
14+
#include <executorch/runtime/executor/method_meta.h>
15+
16+
namespace executorch {
17+
namespace extension {
18+
namespace llm {
19+
20+
/**
21+
* @brief Base class for managing input/output operations for LLM inference.
22+
*
23+
* IOManagerBase provides an interface for handling the input preparation and
24+
* output processing for both prefill and decode phases of LLM inference.
25+
* Derived classes must implement the virtual methods to provide specific IO
26+
* management functionality.
27+
*/
28+
class ET_EXPERIMENTAL IOManagerBase {
29+
public:
30+
/**
31+
* @brief Virtual destructor to allow proper cleanup in derived classes.
32+
*/
33+
virtual ~IOManagerBase() = default;
34+
35+
/**
36+
* @brief Load the IO manager with method metadata for prefill and
37+
* decode operations.
38+
*
39+
* @param prefill_method The prefill method to initialize with.
40+
* @param decode_method The decode method to initialize with.
41+
*/
42+
ET_NODISCARD virtual runtime::Error load(
43+
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
44+
executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) = 0;
45+
46+
/**
47+
* @brief Returns if the IOManager is loaded
48+
*/
49+
virtual bool is_loaded() = 0;
50+
51+
/**
52+
* @brief Reset the IO manager state.
53+
*
54+
* @param prefill_method The prefill method to reset with.
55+
* @param decode_method The decode method to reset with.
56+
*/
57+
ET_NODISCARD virtual runtime::Error reset(
58+
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
59+
executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) = 0;
60+
61+
/**
62+
* @brief Prepare inputs for the prefill phase of LLM inference.
63+
*
64+
* @param input The input tensor containing token IDs.
65+
* @param start_pos The tensor containing the starting position of the current
66+
* input within the context.
67+
* @param prefill_method The prefill method to prepare inputs for.
68+
* @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
69+
* for the prefill method.
70+
*/
71+
virtual runtime::Result<std::vector<executorch::runtime::EValue>>
72+
prepare_prefill(
73+
const executorch::extension::TensorPtr& input,
74+
const executorch::extension::TensorPtr& start_pos,
75+
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) = 0;
76+
77+
/**
78+
* @brief Prepare inputs for the decode phase of LLM inference.
79+
*
80+
* @param input The input tensor containing token IDs.
81+
* @param start_pos The tensor containing the starting position of the current
82+
* input within the context.
83+
* @param decode_method The decode method to prepare inputs for.
84+
* @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
85+
* for the decode method.
86+
*/
87+
virtual runtime::Result<std::vector<executorch::runtime::EValue>>
88+
prepare_decode(
89+
const executorch::extension::TensorPtr& input,
90+
const executorch::extension::TensorPtr& start_pos,
91+
executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) = 0;
92+
93+
/**
94+
* @brief Process and update internal state with outputs from the prefill
95+
* phase.
96+
*
97+
* @param prefill_method The prefill method to update with outputs.
98+
* @param model_outputs Vector of outputs from the prefill method execution.
99+
*/
100+
ET_NODISCARD virtual runtime::Error update_prefill(
101+
executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
102+
const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
103+
104+
/**
105+
* @brief Process and update internal state with outputs from the decode
106+
* phase.
107+
*
108+
* @param decode_method The decode method to update with outputs.
109+
* @param model_outputs Vector of outputs from the decode method execution.
110+
*/
111+
ET_NODISCARD virtual runtime::Error update_decode(
112+
const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
113+
const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
114+
};
115+
116+
} // namespace llm
117+
} // namespace extension
118+
} // namespace executorch
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Declares the `io_manager` library twice: with and without the `_aten` suffix.

    Each variant exports only `io_manager.h`, and every dependency label
    receives the same suffix as the library name.
    """
    base_deps = [
        "//executorch/extension/module:module",
        "//executorch/extension/tensor:tensor",
        "//executorch/runtime/core/exec_aten:lib",
        "//executorch/runtime/executor:program_no_prim_ops",
    ]

    # The `_aten` variant is declared first, followed by the unsuffixed one.
    for suffix in ("_aten", ""):
        # Interface for IOManager. No concrete impl from this dep.
        runtime.cxx_library(
            name = "io_manager" + suffix,
            exported_headers = ["io_manager.h"],
            deps = [label + suffix for label in base_deps],
            visibility = ["@EXECUTORCH_CLIENTS"],
        )

0 commit comments

Comments
 (0)