Commit 58e2792
[llm] Support different shape of input_pos
For Hugging Face models, `forward()` takes `tokens` as well as `cache_positions`, a list of cache indices. This differs from the .pte files that `export_llama` produces, which take `tokens` and `input_pos`, where `input_pos` is a scalar tensor. This PR adds support in `text_decoder_runner.cpp` for handling both shapes of `input_pos`/`cache_positions`. To keep the logic generic without relying on extra metadata, the runner inspects the method meta and input tensor info to decide whether to feed in `input_pos` or `cache_positions`.

Differential Revision: [D77203700](https://our.internmc.facebook.com/intern/diff/D77203700/)

[ghstack-poisoned]
1 parent daf808e commit 58e2792
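
In short, `step()` now inspects the method metadata at run time instead of taking a `use_kv_cache` flag from the caller. A condensed sketch of the decision, lifted from the `text_decoder_runner.cpp` hunk below (error handling trimmed):

auto method_meta = ET_UNWRAP(module_->method_meta("forward"));
// A single input means the model was exported without a KV cache.
bool use_kv_cache = method_meta.num_inputs() > 1;
if (use_kv_cache) {
  // The second input is input_pos (numel == 1) or cache_positions
  // (numel == max_seq_len).
  auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
  int64_t numel = 1;
  std::vector<::executorch::aten::SizesType> sizes_vec;
  for (const auto& size : second_input_info.sizes()) {
    sizes_vec.emplace_back(size);
    numel *= size;
  }
  // The last dimension carries the variable token count.
  sizes_vec[sizes_vec.size() - 1] = -1;
  TensorPtr start_pos_tensor;
  if (numel > 1) {
    // cache_positions: the 1-D range [start_pos, start_pos + n_tokens).
    start_pos_tensor = arange(
        start_pos,
        start_pos + tokens->numel(),
        1,
        sizes_vec,
        ::executorch::aten::ScalarType::Long);
  } else {
    // input_pos: a single-element tensor holding the scalar position.
    start_pos_tensor =
        from_blob(&start_pos, {1}, ::executorch::aten::ScalarType::Long);
  }
  auto outputs_res = module_->forward({tokens, start_pos_tensor});
}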

File tree: 12 files changed, +317 −30 lines


extension/llm/runner/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
 include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
 
 set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp
-    test_text_prefiller.cpp
+    test_text_prefiller.cpp test_text_decoder_runner.cpp
 )
 
 et_cxx_test(

extension/llm/runner/test/targets.bzl

Lines changed: 15 additions & 0 deletions
@@ -36,3 +36,18 @@ def define_common_targets():
            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
        ],
    )
+
+    runtime.cxx_test(
+        name = "test_text_decoder_runner",
+        srcs = ["test_text_decoder_runner.cpp"],
+        deps = [
+            "//executorch/extension/llm/runner:runner_lib",
+            "//executorch/kernels/portable:generated_lib",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        ],
+        env = {
+            "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
+            "KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])",
+            "NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])",
+        },
+    )
extension/llm/runner/test/test_text_decoder_runner.cpp

Lines changed: 199 additions & 0 deletions

@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
+ */
+
+#include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/module/module.h>
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <cstdlib>
+
+using namespace ::testing;
+using executorch::extension::Module;
+using executorch::extension::TensorPtr;
+using executorch::extension::llm::TextDecoderRunner;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::Result;
+using executorch::runtime::testing::TensorFactory;
+
+// Mock Module class for testing
+class MockModule : public Module {
+ public:
+  MockModule() : Module("") {}
+};
+
+class TextDecoderRunnerTest : public Test {
+ protected:
+  void SetUp() override {
+    mock_module_ = std::make_unique<MockModule>();
+    runner_ = std::make_unique<TextDecoderRunner>(mock_module_.get());
+  }
+
+  std::unique_ptr<MockModule> mock_module_;
+  std::unique_ptr<TextDecoderRunner> runner_;
+};
+
+// Test logits_to_token() method with Float tensor
+TEST_F(TextDecoderRunnerTest, LogitsToTokenFloat) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
+  auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // With temperature 0, should return the argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with 3D tensor (batch, seq_length, vocab_size)
+TEST_F(TextDecoderRunnerTest, LogitsToToken3D) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
+  // Shape: [1, 2, 4] - batch=1, seq_length=2, vocab_size=4
+  auto logits = tf_float.make(
+      {1, 2, 4},
+      {
+          0.1f, 0.2f, 0.3f, 0.4f, // First sequence position
+          0.5f, 0.6f, 0.9f, 0.8f // Second sequence position (last)
+      });
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // Should use the last sequence position and return argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with Half tensor
+TEST_F(TextDecoderRunnerTest, LogitsToTokenHalf) {
+  TensorFactory<executorch::aten::ScalarType::Half> tf_half;
+  auto logits = tf_half.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // With temperature 0, should return the argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with BFloat16 tensor
+TEST_F(TextDecoderRunnerTest, LogitsToTokenBFloat16) {
+  TensorFactory<executorch::aten::ScalarType::BFloat16> tf_bfloat16;
+  auto logits = tf_bfloat16.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // With temperature 0, should return the argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with non-zero temperature
+TEST_F(TextDecoderRunnerTest, LogitsToTokenWithTemperature) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
+  auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature > 0 (stochastic)
+  int32_t token = runner_->logits_to_token(logits, 1.0f);
+
+  // With temperature > 0, result should be within valid range
+  EXPECT_GE(token, 0);
+  EXPECT_LT(token, 4);
+}
+
+// Test step() method with all available PTE models
+TEST_F(TextDecoderRunnerTest, StepWithAllModels) {
+  // List of all environment variables for PTE models
+  std::vector<std::pair<std::string, const char*>> env_vars = {
+      {"KVCACHE_CACHE_POS", "KVCACHE_CACHE_POS"},
+      {"KVCACHE_INPUT_POS", "KVCACHE_INPUT_POS"},
+      {"NO_KVCACHE", "NO_KVCACHE"}};
+
+  // Check if any environment variables are set up front
+  bool any_env_set = false;
+  for (const auto& [model_name, env_var] : env_vars) {
+    if (std::getenv(env_var)) {
+      any_env_set = true;
+      break;
+    }
+  }
+
+  // Skip test if no environment variables are set
+  if (!any_env_set) {
+    GTEST_SKIP() << "No PTE model environment variables were set";
+  }
+
+  bool any_model_tested = false;
+
+  // Loop through all available models
+  for (const auto& [model_name, env_var] : env_vars) {
+    const char* model_path = std::getenv(env_var);
+    if (!model_path) {
+      continue; // Skip if environment variable not set
+    }
+
+    SCOPED_TRACE(
+        "Testing model: " + model_name + " from " + std::string(model_path));
+
+    // Load the model
+    auto module = std::make_unique<Module>(model_path);
+    auto load_result = module->load();
+    if (load_result != Error::Ok) {
+      ADD_FAILURE() << "Failed to load model " << model_name << " from "
+                    << model_path << " with error: " << (int)load_result;
+      continue;
+    }
+
+    // Create TextDecoderRunner
+    TextDecoderRunner runner(module.get());
+    auto runner_load_result = runner.load();
+    ASSERT_EQ(runner_load_result, Error::Ok)
+        << "Failed to load runner for " << model_name;
+
+    // Verify method is loaded
+    EXPECT_TRUE(runner.is_method_loaded())
+        << "Method not loaded for " << model_name;
+
+    // Create input tensor pointer
+    TensorFactory<executorch::aten::ScalarType::Long> tf_long;
+    auto input_tokens_ =
+        tf_long.make({1, 3}, {50, 7, 11}); // Three-token input
+    auto input_ptr = std::make_shared<executorch::aten::Tensor>(input_tokens_);
+    int64_t start_pos = 0;
+
+    // Call step() and verify result is ok
+    auto result = runner.step(input_ptr, start_pos);
+    ASSERT_TRUE(result.ok()) << "step() failed for " << model_name
+                             << " with error: " << (int)result.error();
+
+    // Verify output tensor is valid
+    auto output_tensor = result.get();
+    EXPECT_GT(output_tensor.numel(), 0)
+        << "Output tensor empty for " << model_name;
+
+    // Test logits_to_token works
+    int32_t token = runner.logits_to_token(output_tensor, 0.0f);
+    EXPECT_GE(token, 0) << "Invalid token for " << model_name;
+
+    any_model_tested = true;
+  }
+
+  // This should not happen since we checked environment variables up front
+  ASSERT_TRUE(any_model_tested)
+      << "No models were tested despite environment variables being set";
+}

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 3 additions & 4 deletions
@@ -63,11 +63,11 @@ class MockModule : public ::executorch::extension::Module {
 
 class MockTextDecoderRunner : public TextDecoderRunner {
  public:
-  MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
+  MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
   MOCK_METHOD(
       Result<executorch::aten::Tensor>,
       step,
-      (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
+      (executorch::extension::TensorPtr&, int64_t),
       ());
   MOCK_METHOD(bool, is_method_loaded, (), ());
   MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());

@@ -134,8 +134,7 @@ class RunnerTest : public Test {
   std::unique_ptr<MockTextDecoderRunner> createMockTextDecoderRunner() {
     auto text_decoder_runner = std::make_unique<MockTextDecoderRunner>();
     ON_CALL(*text_decoder_runner, step)
-        .WillByDefault([&](executorch::extension::TensorPtr&,
-                           executorch::extension::TensorPtr&) {
+        .WillByDefault([&](executorch::extension::TensorPtr&, int64_t) {
           return Result<executorch::aten::Tensor>(tensor);
         });
     ON_CALL(*text_decoder_runner, is_method_loaded())

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 39 additions & 5 deletions
@@ -21,18 +21,52 @@ namespace llm {
 // NOTE: we observed ~2x loading performance increase on iPhone 15
 // and a ~5% improvement on Galaxy S22 by switching to
 // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-TextDecoderRunner::TextDecoderRunner(Module* module, bool use_kv_cache)
-    : module_(module), use_kv_cache_(use_kv_cache) {}
+TextDecoderRunner::TextDecoderRunner(Module* module) : module_(module) {}
 
 // This function is functional, meaning it shouldn't modify any state of the
 // input. It should be safe to call multiple times with the same inputs. The
 // outer loop (call site) is responsible for managing state.
 ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
     TensorPtr& tokens,
-    TensorPtr& start_pos) {
+    int64_t start_pos) {
   // ET_LOG(Info, "Input token %" PRIu64, input_token);
-  if (use_kv_cache_) {
-    auto outputs_res = module_->forward({tokens, start_pos});
+  auto method_meta = ET_UNWRAP(module_->method_meta("forward"));
+  // If only 1 input, we are not using kv cache
+  bool use_kv_cache = method_meta.num_inputs() > 1;
+
+  if (use_kv_cache) {
+    // Size of the second argument. This could be either input_pos or
+    // cache_positions
+
+    // Check if we are using cache positions instead of input pos.
+    auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
+    // For input_pos, numel is 1; for cache_positions, numel is max_seq_len
+    auto sizes = second_input_info.sizes();
+    auto numel = 1;
+    std::vector<::executorch::aten::SizesType> sizes_vec;
+    for (const auto& size : sizes) {
+      sizes_vec.emplace_back(size);
+      numel *= size;
+    }
+    // Assuming the last dimension is the one with the variable token length
+    sizes_vec[sizes_vec.size() - 1] = -1;
+    TensorPtr start_pos_tensor;
+    if (numel > 1) {
+      // Assuming the model is exported with cache_positions, create a tensor
+      // with the same size as cache_positions
+      start_pos_tensor = arange(
+          start_pos,
+          start_pos + tokens->numel(),
+          1,
+          sizes_vec,
+          ::executorch::aten::ScalarType::Long);
+    } else {
+      // Assuming the model is exported with input_pos, create a tensor of
+      // size 1
+      start_pos_tensor =
+          from_blob(&start_pos, {1}, ::executorch::aten::ScalarType::Long);
+    }
+    ET_LOG(Info, "Start pos tensor numel: %zu", start_pos_tensor->numel());
+    auto outputs_res = module_->forward({tokens, start_pos_tensor});
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_CHECK_MSG(
         outputs_res.get().size() == 1,
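
For callers, the second argument of `step()` is now a plain `int64_t`, and the runner materializes whichever position tensor the model expects. A minimal usage sketch, assuming an exported model at the hypothetical path "model.pte" (the token ids are also illustrative, not from this diff):

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using executorch::extension::Module;
using executorch::extension::from_blob;
using executorch::extension::llm::TextDecoderRunner;

int main() {
  // "model.pte" is a placeholder path.
  Module module("model.pte");
  TextDecoderRunner runner(&module);
  if (runner.load() != executorch::runtime::Error::Ok) {
    return 1;
  }

  int64_t token_ids[] = {50, 7, 11};
  auto tokens =
      from_blob(token_ids, {1, 3}, ::executorch::aten::ScalarType::Long);
  int64_t start_pos = 0; // plain integer; no position tensor built by hand
  auto logits = runner.step(tokens, start_pos);
  if (logits.ok()) {
    // Greedy decode of the next token (temperature 0).
    int32_t next = runner.logits_to_token(logits.get(), 0.0f);
    (void)next;
  }
  return 0;
}

The same simplification shows up in `text_prefiller.cpp` below, where the hand-built `start_pos_tensor` disappears from both call sites.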

extension/llm/runner/text_decoder_runner.h

Lines changed: 2 additions & 3 deletions
@@ -21,7 +21,7 @@ namespace llm {
 
 class ET_EXPERIMENTAL TextDecoderRunner {
  public:
-  TextDecoderRunner(Module* module, bool use_kv_cache);
+  TextDecoderRunner(Module* module);
 
   virtual ~TextDecoderRunner() = default;
 
@@ -34,7 +34,7 @@ class ET_EXPERIMENTAL TextDecoderRunner {
   */
  virtual ::executorch::runtime::Result<executorch::aten::Tensor> step(
      TensorPtr& input,
-      TensorPtr& start_pos);
+      int64_t start_pos);
 
  /**
   * Load the Module for text decode purpose.
@@ -101,7 +101,6 @@ class ET_EXPERIMENTAL TextDecoderRunner {
   * Module remains valid for the duration of TextDecoderRunner's usage.
   */
  Module* module_;
-  bool use_kv_cache_;
  bool should_stop_{false};
 };
 
extension/llm/runner/text_llm_runner.cpp

Lines changed: 1 addition & 2 deletions
@@ -393,8 +393,7 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
 
   // Create text_decoder_runner. Use a shared_ptr so that it can be shared with
   // TextPrefiller and TextTokenGenerator
-  auto text_decoder_runner = std::make_unique<TextDecoderRunner>(
-      module.get(), metadata.at(kUseKVCache));
+  auto text_decoder_runner = std::make_unique<TextDecoderRunner>(module.get());
 
   // Create text_prefiller
   auto text_prefiller = std::make_unique<TextPrefiller>(

extension/llm/runner/text_prefiller.cpp

Lines changed: 3 additions & 10 deletions
@@ -86,10 +86,7 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
       {1, num_prompt_tokens},
       executorch::aten::ScalarType::Long);
 
-  auto start_pos_tensor =
-      from_blob(&start_pos, {1}, executorch::aten::ScalarType::Long);
-
-  auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);
+  auto outputs_res = text_decoder_runner_->step(tokens, start_pos);
 
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
   ET_LOG(

@@ -106,13 +103,10 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
   auto tokens =
       from_blob(&cur_token, {1, 1}, executorch::aten::ScalarType::Long);
 
-  auto start_pos_tensor =
-      from_blob(&start_pos, {1}, executorch::aten::ScalarType::Long);
-
   // run the first token and get back logits tensor. Assuming the first token
   // is bos so don't callback.
   auto logits_tensor =
-      ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
+      ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
 
   pos += 1; // start the loop from index 1
   start_pos += 1;

@@ -122,8 +116,7 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
   // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
   cur_token = prompt_tokens[pos];
 
-  logits_tensor =
-      ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
+  logits_tensor = ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
 
   pos++;
   start_pos++;
