Skip to content

Commit 9481d79

Browse files
committed
Update on "[llm] Support different shape of input_pos"
For Hugging Face models, `forward()` takes `tokens` as well as `cache_position`, which is a list of cache indices. This differs from the .pte files that `export_llama` produces, which take `tokens` and `input_pos`, where `input_pos` is a scalar tensor. This PR adds support inside `text_decoder_runner.cpp` for handling both shapes of `input_pos`/`cache_position`. To keep the logic generic without relying on extra metadata, it inspects the method meta and input tensor info to decide whether to feed in `input_pos` or `cache_position`. Differential Revision: [D77203700](https://our.internmc.facebook.com/intern/diff/D77203700/) [ghstack-poisoned]
2 parents 64aab38 + bb9d224 commit 9481d79

File tree

3 files changed

+19
-20
lines changed

3 files changed

+19
-20
lines changed

extension/llm/runner/test/TARGETS

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,22 @@
88
# targets.bzl. This file can contain fbcode-only targets.
99

1010
load(":targets.bzl", "define_common_targets")
11-
11+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
1212
oncall("executorch")
1313

1414
define_common_targets()
15+
16+
runtime.cxx_test(
17+
name = "test_text_decoder_runner",
18+
srcs = ["test_text_decoder_runner.cpp"],
19+
deps = [
20+
"//executorch/extension/llm/runner:runner_lib",
21+
"//executorch/kernels/portable:generated_lib",
22+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
23+
],
24+
env = {
25+
"KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
26+
"KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])",
27+
"NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])",
28+
}
29+
)

extension/llm/runner/test/targets.bzl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,3 @@ def define_common_targets():
3636
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
3737
],
3838
)
39-
40-
runtime.cxx_test(
41-
name = "test_text_decoder_runner",
42-
srcs = ["test_text_decoder_runner.cpp"],
43-
deps = [
44-
"//executorch/extension/llm/runner:runner_lib",
45-
"//executorch/kernels/portable:generated_lib",
46-
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
47-
],
48-
env = {
49-
"KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
50-
"KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])",
51-
"NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])",
52-
}
53-
)

extension/llm/runner/test/test_text_prefiller.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ using executorch::runtime::testing::TensorFactory;
2424
// Mock class for TextDecoderRunner
2525
class MockTextDecoderRunner : public TextDecoderRunner {
2626
public:
27-
MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
27+
MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
2828
MOCK_METHOD(
2929
Result<executorch::aten::Tensor>,
3030
step,
31-
(executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
31+
(executorch::extension::TensorPtr&, int64_t),
3232
());
3333
MOCK_METHOD(bool, is_method_loaded, (), ());
3434
MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
@@ -44,8 +44,7 @@ class TextPrefillerTest : public Test {
4444
ON_CALL(text_decoder_runner_, is_method_loaded())
4545
.WillByDefault(Return(true));
4646
ON_CALL(text_decoder_runner_, step)
47-
.WillByDefault([&](executorch::extension::TensorPtr&,
48-
executorch::extension::TensorPtr&) {
47+
.WillByDefault([&](executorch::extension::TensorPtr&, int64_t) {
4948
return Result<executorch::aten::Tensor>(tensor);
5049
});
5150
}

0 commit comments

Comments
 (0)