Skip to content

Commit e515079

Browse files
IOManager Interface (#10418)
Summary: Hopefully this is sufficient for the contract. Going to do two follow-ups: add a basic CPU implementation and add a static attention implementation. Differential Revision: D73450877 cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng
1 parent f35de65 commit e515079

File tree

7 files changed

+454
-0
lines changed

7 files changed

+454
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

# Owning oncall rotation for this directory.
oncall("executorch")

# Instantiate the targets shared between fbcode and xplat (see targets.bzl).
define_common_targets()
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <vector>

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>

namespace executorch {
namespace extension {
namespace llm {

/**
 * @brief Base class for managing input/output operations for LLM inference.
 *
 * IOManager provides an interface for handling the input preparation and
 * output processing for both the prefill and decode phases of LLM inference.
 * The default implementations below are no-op pass-throughs suitable for
 * models whose only method inputs are the token ids and the start position
 * (i.e. models that manage KV caches / masks internally). Derived classes
 * may override any of these hooks to provide specific IO management
 * (e.g. a static-attention cache layout).
 */
class ET_EXPERIMENTAL IOManager {
 public:
  /**
   * @brief Virtual destructor to allow proper cleanup in derived classes.
   */
  virtual ~IOManager() = default;

  /**
   * @brief Load the IO manager with method metadata for prefill and
   * decode operations. The default implementation does nothing and
   * returns Error::Ok.
   *
   * @param program The program prefill and decode methods are loaded from.
   * @param prefill_method The prefill method to initialize with.
   * @param decode_method The decode method to initialize with.
   * @return Error::Ok on success.
   */
  ET_NODISCARD virtual runtime::Error load(
      const executorch::ET_RUNTIME_NAMESPACE::Program& program,
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
    (void)program;
    (void)prefill_method;
    (void)decode_method;
    return runtime::Error::Ok;
  }

  /**
   * @brief Reset the IO manager state. The default implementation has no
   * state, so this does nothing and returns Error::Ok.
   *
   * @param prefill_method The prefill method to reset with.
   * @param decode_method The decode method to reset with.
   * @return Error::Ok on success.
   */
  ET_NODISCARD virtual runtime::Error reset(
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
    (void)prefill_method;
    (void)decode_method;
    return runtime::Error::Ok;
  }

  /**
   * @brief Prepare inputs for the prefill phase of LLM inference.
   *
   * The default implementation forwards (input, start_pos) unchanged and
   * requires the method to take exactly those two inputs; otherwise it
   * logs and returns Error::InvalidState (e.g. for models that take the
   * caches or mask as extra arguments).
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the
   * current input within the context.
   * @param prefill_method The prefill method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared
   * inputs for the prefill method, or an error.
   */
  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
  prepare_prefill(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
    if (prefill_method.inputs_size() != 2) {
      ET_LOG(
          Error,
          "Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
          prefill_method.inputs_size());
      return runtime::Error::InvalidState;
    }
    // The default (CPU) IO manager relies on dynamic shapes for prefill, so
    // there is no reshaping or staging work to be done here.
    return std::vector<runtime::EValue>{input, start_pos};
  }

  /**
   * @brief Prepare inputs for the decode phase of LLM inference.
   *
   * The default implementation forwards (input, start_pos) unchanged and
   * requires the method to take exactly those two inputs; otherwise it
   * logs and returns Error::InvalidState.
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the
   * current input within the context.
   * @param decode_method The decode method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared
   * inputs for the decode method, or an error.
   */
  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
  prepare_decode(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
    if (decode_method.inputs_size() != 2) {
      ET_LOG(
          Error,
          "Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
          decode_method.inputs_size());
      return runtime::Error::InvalidState;
    }
    // Decode inputs also pass straight through (the original comment said
    // "prefill" here — copy/paste; this is the decode path).
    return std::vector<runtime::EValue>{input, start_pos};
  }

  /**
   * @brief Process and update internal state with outputs from the prefill
   * phase. The default implementation is a no-op.
   *
   * @param prefill_method The prefill method to update with outputs.
   * @param model_outputs Vector of outputs from the prefill method execution.
   * @return Error::Ok on success.
   */
  ET_NODISCARD virtual runtime::Error update_prefill(
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) {
    (void)prefill_method;
    (void)model_outputs;
    // No post inference work to do.
    return runtime::Error::Ok;
  }

  /**
   * @brief Process and update internal state with outputs from the decode
   * phase. The default implementation is a no-op.
   *
   * NOTE(review): this takes const Method& while update_prefill takes a
   * non-const Method& — presumably intentional, but confirm the asymmetry
   * is wanted before relying on it in derived classes.
   *
   * @param decode_method The decode method to update with outputs.
   * @param model_outputs Vector of outputs from the decode method execution.
   * @return Error::Ok on success.
   */
  ET_NODISCARD virtual runtime::Error update_decode(
      const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) {
    (void)decode_method;
    (void)model_outputs;
    // No post inference work to do.
    return runtime::Error::Ok;
  }
};

} // namespace llm
} // namespace extension
} // namespace executorch
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Registers the IOManager interface library in both ATen and portable flavors."""

    # One library per flavor: "_aten" (ATen tensors) and "" (portable).
    for suffix in ("_aten", ""):
        # Interface for IOManager. No concrete impl from this dep.
        runtime.cxx_library(
            name = "io_manager" + suffix,
            exported_headers = ["io_manager.h"],
            deps = [
                "//executorch/extension/tensor:tensor" + suffix,
                "//executorch/runtime/core/exec_aten:lib" + suffix,
                "//executorch/runtime/executor:program_no_prim_ops" + suffix,
            ],
            visibility = ["@EXECUTORCH_CLIENTS"],
        )
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

# Owning oncall rotation for this directory.
oncall("executorch")

# Instantiate any shared targets (currently none — see targets.bzl).
define_common_targets()

# Unit test for the IOManager interface; runs against an exported KV-cache
# model whose path is injected through the KVCACHE_CACHE_POS env var.
runtime.cxx_test(
    name = "test_io_manager",
    srcs = ["test_io_manager.cpp"],
    deps = [
        # Fix: the original listed the io_manager dep twice; duplicate removed.
        "//executorch/extension/llm/runner/io_manager:io_manager",
        "//executorch/extension/module:module",
        "//executorch/extension/tensor:tensor",
        "//executorch/runtime/executor:program",
        "//executorch/kernels/portable:generated_lib",
    ],
    env = {
        "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
    },
)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    # Intentionally empty: this directory defines no targets shared between
    # fbcode and xplat yet. The BUCK file still calls this for consistency
    # with sibling directories, so keep the function as a no-op stub.
    pass

0 commit comments

Comments
 (0)