
Commit 82608bf

Gh/larryliu0820/46/base (#4643)
This commit lands a ghstack stack of [llava] runner refactors:

* [llava][13/N] Move metadata util to a separate header for reuse.

* [llava][14/N] Refactor runner prefill() and run_model_step(). This refactoring is needed to extract prefill() and run_model_step() out of the runner so that these APIs become replaceable and easy to plug in and reuse.
  * prefill(): when parallel prefill is enabled, or when the kv cache is not used, the model can accept a block of more than one token at a time. When the kv cache is used but parallel prefill is not enabled, only one token can be fed in per step.
  * run_model_step(): this function should not update the input. Instead it runs the model differently depending on whether the kv cache is enabled, and returns the next token directly. All input updates happen in the generation loop.
  Differential Revision: [D60840327](https://our.internmc.facebook.com/intern/diff/D60840327)

* [llava][15/N] Extract out text decoder runner. The previous PR (#4556) refactored run_model_step() so that it could be extracted into a separate class. The new TextDecoderRunner provides two APIs:
  * step(tokens, start_pos): takes one or more tokens plus a start position, feeds them into the Module, and returns a tensor of logits.
  * logits_to_token(logits): samples from the logits and returns a token. We don't expect this logic to change across different runners.
  Differential Revision: [D60856571](https://our.internmc.facebook.com/intern/diff/D60856571)

* [llava][16/N] Extract out prefill logic into a new class. Depending on whether parallel or sequential prefill is chosen, prefill() calls TextDecoderRunner.step() to prefill the prompt tokens into the LLM.
  Differential Revision: [D60927756](https://our.internmc.facebook.com/intern/diff/D60927756)

* [llava][17/N] Move util.h into extension/llm/runner so that it can be reused.
  Differential Revision: D60938984
  Pull Request resolved: #4588
1 parent c04bc99 commit 82608bf
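For context, the commit message above describes how the refactor splits prompt prefill from token-by-token decoding. Below is a minimal sketch of how a generation loop could wire the two new classes together. It is illustrative only, not the literal runner.cpp change (that diff is collapsed below): `prompt_tokens`, `seq_len`, `eos_id_`, and `token_callback` are assumed to be provided by the caller, and the snippet is assumed to live inside a runner method returning Error so that ET_UNWRAP / ET_CHECK_OK_OR_RETURN_ERROR can propagate failures.

// Sketch only: assumes the kv cache is enabled and prompt_tokens was already
// produced by the tokenizer.
// TextPrefiller feeds the whole prompt (parallel) or one token at a time
// (sequential) and returns the token to start decoding from.
uint64_t cur_token = ET_UNWRAP(
    text_prefiller_->prefill(prompt_tokens, /*start_pos=*/0, token_callback));
int64_t pos = prompt_tokens.size();

// Decode loop: wrap the current token and position, step the decoder once,
// then sample the next token from the returned logits.
ManagedTensor tokens(&cur_token, {1, 1}, ScalarType::Long);
ManagedTensor start_pos(&pos, {1}, ScalarType::Long);
while (pos < seq_len) {
  Result<exec_aten::Tensor> logits_res =
      text_decoder_runner_->step(tokens, start_pos);
  ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
  uint64_t next_token = text_decoder_runner_->logits_to_token(logits_res.get());
  token_callback(ET_UNWRAP(tokenizer_->decode(cur_token, next_token)));
  if (next_token == eos_id_) {
    break;
  }
  cur_token = next_token;  // ManagedTensor aliases cur_token and pos, so
  pos++;                   // updating them updates the next step's inputs.
}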

File tree

10 files changed (+447, -293 lines)

examples/models/llama2/runner/runner.cpp

Lines changed: 61 additions & 273 deletions
Large diffs are not rendered by default.

examples/models/llama2/runner/runner.h

Lines changed: 11 additions & 15 deletions
@@ -19,6 +19,8 @@
 #include <unordered_map>
 
 #include <executorch/extension/llm/runner/stats.h>
+#include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/runner/text_prefiller.h>
 #include <executorch/extension/llm/sampler/sampler.h>
 #include <executorch/extension/llm/tokenizer/tokenizer.h>
 #include <executorch/extension/module/module.h>
@@ -44,17 +46,6 @@ class Runner {
   void stop();
 
  private:
-  int32_t logitsToToken(const exec_aten::Tensor& logits_tensor);
-  Result<exec_aten::Tensor> prefill(
-      const std::vector<uint64_t>& tokens,
-      ManagedTensor& managed_tokens,
-      ManagedTensor& managed_start_pos,
-      std::function<void(const std::string&)> token_callback);
-  Result<exec_aten::Tensor> run_model_step(
-      int64_t input_token,
-      ManagedTensor& tokens,
-      ManagedTensor& start_pos,
-      size_t max_seq_len);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;
@@ -65,16 +56,21 @@ class Runner {
   bool use_kv_cache_;
   bool use_sdpa_with_kv_cache_;
   bool append_eos_;
+  float temperature_;
+  bool enable_parallel_prefill_;
+  bool shouldStop_{false};
+
+  // model
   std::unordered_set<std::string> model_methods_;
   std::string model_path_;
   std::unique_ptr<Module> module_;
+  std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
+  std::unique_ptr<TextPrefiller> text_prefiller_;
   std::string tokenizer_path_;
-  float temperature_;
   std::unique_ptr<Tokenizer> tokenizer_;
-  std::unique_ptr<Sampler> sampler_;
-  bool shouldStop_{false};
+
+  // stats
   Stats stats_;
-  bool enable_parallel_prefill_;
 };
 
 } // namespace torch::executor
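The Sampler member disappears from the runner because sampling now lives inside TextDecoderRunner. A plausible way the runner could construct the new members during load, sketched here under the assumption that it happens after module_ and tokenizer_ exist (the actual wiring is in the collapsed runner.cpp diff above):

// Sketch: constructor signatures match text_decoder_runner.cpp and
// text_prefiller.cpp below; member names match runner.h above.
text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
    module_.get(), use_kv_cache_, vocab_size_, temperature_);
text_prefiller_ = std::make_unique<TextPrefiller>(
    tokenizer_.get(), text_decoder_runner_.get(), use_kv_cache_,
    enable_parallel_prefill_);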

examples/models/llama2/runner/targets.bzl

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,6 @@ def define_common_targets():
         ],
         exported_headers = [
             "runner.h",
-            "util.h",
         ],
         preprocessor_flags = [
             "-DUSE_ATEN_LIB",
@@ -34,7 +33,8 @@ def define_common_targets():
         exported_deps = [
             "//executorch/backends/xnnpack:xnnpack_backend",
             "//executorch/extension/llm/runner:stats",
-            "//executorch/extension/llm/sampler:sampler" + aten_suffix,
+            "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
+            "//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
             "//executorch/extension/evalue_util:print_evalue" + aten_suffix,
             "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
             "//executorch/extension/module:module" + aten_suffix,

extension/llm/runner/targets.bzl

Lines changed: 37 additions & 1 deletion
@@ -3,8 +3,44 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 def define_common_targets():
     runtime.cxx_library(
         name = "stats",
-        exported_headers = ["stats.h"],
+        exported_headers = [
+            "stats.h",
+            "util.h",
+        ],
         visibility = [
             "@EXECUTORCH_CLIENTS",
         ],
     )
+
+    for aten in (True, False):
+        aten_suffix = "_aten" if aten else ""
+
+        runtime.cxx_library(
+            name = "text_decoder_runner" + aten_suffix,
+            exported_headers = ["text_decoder_runner.h"],
+            srcs = ["text_decoder_runner.cpp"],
+            visibility = [
+                "@EXECUTORCH_CLIENTS",
+            ],
+            exported_deps = [
+                ":stats",
+                "//executorch/extension/llm/sampler:sampler" + aten_suffix,
+                "//executorch/extension/module:module" + aten_suffix,
+                "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
+            ],
+        )
+
+        runtime.cxx_library(
+            name = "text_prefiller" + aten_suffix,
+            exported_headers = ["text_prefiller.h"],
+            srcs = ["text_prefiller.cpp"],
+            visibility = [
+                "@EXECUTORCH_CLIENTS",
+            ],
+            exported_deps = [
+                ":text_decoder_runner" + aten_suffix,
+                "//executorch/extension/llm/tokenizer:tokenizer_header",
+                "//executorch/extension/module:module" + aten_suffix,
+                "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
+            ],
+        )
extension/llm/runner/text_decoder_runner.cpp (new file)

Lines changed: 72 additions & 0 deletions

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Given inputs, run a text decoder and return logits.

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <ctime>

namespace torch::executor {

// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
TextDecoderRunner::TextDecoderRunner(
    Module* module,
    bool use_kv_cache,
    int32_t vocab_size,
    float temperature)
    : module_(module),
      sampler_(std::make_unique<Sampler>(
          vocab_size,
          temperature,
          ::executorch::llm::kTopp,
          static_cast<unsigned long long>(std::time(nullptr)))),
      use_kv_cache_(use_kv_cache) {}

// This function is functional, meaning it shouldn't modify any state of the
// input. It should be safe to call multiple times with the same inputs. The
// outer loop (call site) is responsible for managing state.
Result<exec_aten::Tensor> TextDecoderRunner::step(
    ManagedTensor& managed_tokens,
    ManagedTensor& managed_start_pos) {
  auto tokens = managed_tokens.get_aliasing_tensor();
  // ET_LOG(Info, "Input token %" PRIu64, input_token);
  if (use_kv_cache_) {
    auto start_pos = managed_start_pos.get_aliasing_tensor();
    Result<std::vector<EValue>> outputs_res =
        module_->forward({tokens, start_pos});
    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
    ET_CHECK_MSG(
        outputs_res.get().size() == 1,
        "More than one output returned from executing LLM.");
    ET_CHECK_MSG(
        outputs_res.get()[0].isTensor(),
        "Non Tensor Output returned from executing LLM");

    // Return the logits tensor
    return outputs_res.get()[0].toTensor();
  } else { // no kv cache
    (void)managed_start_pos; // unused

    Result<std::vector<EValue>> outputs_res = module_->forward({tokens});
    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
    ET_CHECK_MSG(
        outputs_res.get().size() == 1,
        "More than one output returned from executing LLM.");
    ET_CHECK_MSG(
        outputs_res.get()[0].isTensor(),
        "Non Tensor Output returned from executing LLM");

    // Return the logits tensor
    return outputs_res.get()[0].toTensor();
  }
}

} // namespace torch::executor
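For reference, a minimal, hypothetical call sequence for step() with the kv cache enabled, mirroring what TextPrefiller does in sequential mode; the variable names and the token value are illustrative only:

// Sketch: assumes `runner` is a constructed TextDecoderRunner whose
// "forward" method has already been loaded.
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/runner_util/managed_tensor.h>

int64_t pos = 0;
uint64_t token = 1; // e.g. BOS; illustrative value only
ManagedTensor tokens(&token, {1, 1}, ScalarType::Long);
ManagedTensor start_pos(&pos, {1}, ScalarType::Long);
Result<exec_aten::Tensor> logits = runner.step(tokens, start_pos);
if (logits.ok()) {
  // Sample from the last position of the [1, num_tokens, vocab_size] logits.
  int32_t next = runner.logits_to_token(logits.get());
  ET_LOG(Info, "next token: %d", next);
}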
extension/llm/runner/text_decoder_runner.h (new file)

Lines changed: 95 additions & 0 deletions

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Given inputs, run a text decoder in LLM and return the output.

#pragma once

#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/managed_tensor.h>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

namespace torch::executor {

class TextDecoderRunner {
 public:
  TextDecoderRunner(
      Module* module,
      bool use_kv_cache,
      int32_t vocab_size,
      float temperature);
  /**
   * Run LLM text decoder with inputs to generate next token.
   * @param input The input to the LLM Module.
   * @param start_pos The starting position in KV cache of the input in the LLM
   * Module.
   * @return The output of the LLM Module. This will be a tensor of logits.
   */
  Result<exec_aten::Tensor> step(
      ManagedTensor& input,
      ManagedTensor& start_pos);

  /**
   * Load the Module for a given method name.
   * @param method_name The name of the method to load.
   * @return The error code.
   */
  inline Error load(const std::string& method_name = "forward") {
    return module_->load_method(method_name);
  }

  /**
   * Check if the Module is loaded.
   * @return True if the Module is loaded, false otherwise.
   */
  inline bool is_method_loaded(const std::string& method_name = "forward") {
    return module_->is_method_loaded(method_name);
  }

  /**
   * Sample the next token from the logits tensor.
   * @param logits_tensor The logits tensor.
   * @return The next token.
   */
  inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
    ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
    auto num_tokens = logits_tensor.size(1);
    auto vocab_size = logits_tensor.size(2);

    switch (logits_tensor.scalar_type()) {
      case ScalarType::Float: {
        float* logits = logits_tensor.mutable_data_ptr<float>();
        float* logits_last = logits;
        logits_last += (num_tokens - 1) * vocab_size;
        return sampler_->sample(logits_last);
      }
      case ScalarType::Half: {
        exec_aten::Half* logits =
            logits_tensor.mutable_data_ptr<exec_aten::Half>();
        exec_aten::Half* logits_last = logits;
        logits_last += (num_tokens - 1) * vocab_size;
        return sampler_->sample(logits_last);
      }
      default:
        ET_CHECK_MSG(
            false,
            "Unsupported dtype output %hhd",
            static_cast<int8_t>(logits_tensor.scalar_type()));
    }
  }

 protected:
  // TODO: use shared_ptr for module
  Module* module_;
  std::unique_ptr<Sampler> sampler_;
  bool use_kv_cache_;
};

} // namespace torch::executor
extension/llm/runner/text_prefiller.cpp (new file)

Lines changed: 104 additions & 0 deletions

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Given a text prompt, encode it using tokenizer and prefill the KV cache of a
// LLM.

#include <executorch/extension/llm/runner/text_prefiller.h>

namespace torch::executor {

TextPrefiller::TextPrefiller(
    Tokenizer* tokenizer,
    TextDecoderRunner* text_decoder_runner,
    bool use_kv_cache,
    bool enable_parallel_prefill)
    : tokenizer_(tokenizer),
      text_decoder_runner_(text_decoder_runner),
      use_kv_cache_(use_kv_cache),
      enable_parallel_prefill_(enable_parallel_prefill) {}

Result<uint64_t> TextPrefiller::prefill(
    std::vector<uint64_t>& prompt_tokens,
    int64_t start_pos,
    std::function<void(const std::string&)> token_callback) {
  ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
  if (!text_decoder_runner_->is_method_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
  }
  // enable_parallel_prefill_ may be set even when not using kv cache.
  // When kv cache is not used, start pos is ignored.
  int32_t num_prompt_tokens = prompt_tokens.size();

  // store the token
  uint64_t cur_token;
  if (enable_parallel_prefill_ || !use_kv_cache_) {
    // initialize tensor wrappers
    ManagedTensor managed_tokens(
        prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);

    ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);

    Result<exec_aten::Tensor> outputs_res =
        text_decoder_runner_->step(managed_tokens, managed_start_pos);

    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
    ET_LOG(
        Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
    ET_CHECK_MSG(
        outputs_res.get().size(1) == num_prompt_tokens,
        "Expected number of output tokens %d does not match returned value %zu.",
        num_prompt_tokens,
        outputs_res.get().size(1));
    // insert new token into prompt_tokens
    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
    uint64_t prev = prompt_tokens[0];
    uint64_t cur;
    for (int i = 1; i < prompt_tokens.size(); i++) {
      cur = prompt_tokens[i];
      token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
      prev = cur;
    }
    cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
  } else { // sequential prefill
    int64_t pos = 0; // position in the sequence
    int64_t prev_token;
    // token & pos
    int64_t pos_data = 0;
    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
    cur_token = prompt_tokens[0];

    // initialize tensor wrappers
    ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);

    ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);

    while (pos < num_prompt_tokens) {
      // Run the model
      pos_data = start_pos + pos;

      Result<exec_aten::Tensor> logits_res =
          text_decoder_runner_->step(managed_tokens, managed_start_pos);

      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
      prev_token = cur_token;

      pos++;

      cur_token = pos == num_prompt_tokens
          ? text_decoder_runner_->logits_to_token(logits_res.get())
          : prompt_tokens[pos];

      // print the token as string, decode it with the Tokenizer object
      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
    }
  }
  return cur_token;
}

} // namespace torch::executor
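A short usage sketch for prefill(); the token ids and the callback body are illustrative, and `prefiller` is assumed to wrap an already-constructed TextDecoderRunner and Tokenizer:

// Sketch: feed a tokenized prompt into the model and stream decoded pieces.
std::vector<uint64_t> prompt_tokens = {1, 15043, 3186}; // illustrative ids
Result<uint64_t> last = prefiller.prefill(
    prompt_tokens, /*start_pos=*/0,
    [](const std::string& piece) { ET_LOG(Info, "%s", piece.c_str()); });
if (last.ok()) {
  // In parallel mode this is the token sampled after the prompt; in
  // sequential mode it is the token sampled at the last prompt position.
  uint64_t next_token = last.get(); // feed this into the decode loop
  (void)next_token;
}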
