Commit 9a698d7

Update on "[llm] Support different shape of input_pos"
For Hugging Face models, `forward()` takes `tokens` and `cache_positions`, a list of cache indices. This differs from the .pte files that `export_llama` produces, which take `tokens` and `input_pos`, where `input_pos` is a scalar tensor. This PR adds support in `text_decoder_runner.cpp` for both shapes of `input_pos`/`cache_positions`. To keep the logic generic without relying on extra metadata, the runner inspects the method meta and the input tensor info to decide whether to feed `input_pos` or `cache_position`.

Differential Revision: [D77203700](https://our.internmc.facebook.com/intern/diff/D77203700/)

[ghstack-poisoned]
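The shape-based choice described in the commit message can be pictured with a small stand-alone sketch. It is not the code in `text_decoder_runner.cpp`: the helper name `make_position_input` and the parameter `expected_numel` are hypothetical, and using the element count of the second input as the deciding criterion is an assumption about what the method metadata exposes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper for illustration only.
// `expected_numel` stands in for the element count that the method's
// metadata reports for its second input (an assumption of this sketch).
std::vector<int64_t> make_position_input(
    int64_t start_pos, int64_t seq_len, int64_t expected_numel) {
  assert(seq_len > 0 && "need at least one token");
  if (expected_numel == 1) {
    // export_llama-style model: a single scalar `input_pos`.
    return {start_pos};
  }
  // Hugging Face-style model: one cache index per token,
  // i.e. [start_pos, start_pos + 1, ..., start_pos + seq_len - 1].
  std::vector<int64_t> positions(static_cast<std::size_t>(seq_len));
  std::iota(positions.begin(), positions.end(), start_pos);
  return positions;
}
```

With `start_pos = 7` and `seq_len = 3`, the scalar branch yields a single value `7`, while the list branch yields `7, 8, 9`. The caller does not need an extra metadata flag, because the decision comes from what the model itself declares about its inputs.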
2 parents 2272fc5 + 5525c1f commit 9a698d7

2 files changed: +9 -3 lines changed
extension/llm/runner/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ add_subdirectory(
 set(runner_deps executorch_core extension_module extension_tensor tokenizers)
 
 target_link_libraries(extension_llm_runner PUBLIC ${runner_deps})
+set_target_properties(extension_llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON)
 
 target_include_directories(
   extension_llm_runner

kernels/portable/cpu/util/arange_util.cpp

Lines changed: 8 additions & 3 deletions
@@ -19,11 +19,16 @@ namespace torch::executor::native {
 
 Tensor::SizesType
 compute_arange_out_size(double start, double end, double step) {
-  ET_CHECK_MSG(
-      end > start, "end (%f) must be greater than start (%f)", end, start);
-  ET_CHECK_MSG(step > 0, "step must be positive, got %f", step);
   Tensor::SizesType numel =
       static_cast<Tensor::SizesType>(std::ceil((end - start) / step));
+
+  ET_CHECK_MSG(
+      numel >= 0,
+      "numel should be non-negative, but got (%d). start (%f), end (%f), step (%f)",
+      numel,
+      start,
+      end,
+      step);
   return numel;
 }
 
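The effect of the relaxed check can be seen in a stand-alone sketch. The removed checks required `end > start` and `step > 0`, which rejected empty and descending ranges. The new check accepts any combination whose element count is non-negative. The name `arange_out_size` is hypothetical, the sketch throws where the kernel would abort through `ET_CHECK_MSG`, and the zero-step guard is an addition of this sketch that is not part of the commit.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for compute_arange_out_size, for illustration only.
int64_t arange_out_size(double start, double end, double step) {
  // Extra guard (not in the commit): a zero step gives an infinite or NaN
  // quotient, and converting that to an integer is undefined behavior.
  if (step == 0.0) {
    throw std::invalid_argument("step must be non-zero");
  }
  // Same formula as the diff: ceil((end - start) / step).
  const auto numel = static_cast<int64_t>(std::ceil((end - start) / step));
  // Same condition as the new ET_CHECK_MSG: only the sign of numel matters.
  if (numel < 0) {
    throw std::invalid_argument("numel should be non-negative");
  }
  return numel;
}
```

Under this rule `arange_out_size(0, 5, 2)` gives 3, the descending range `arange_out_size(5, 0, -1)` gives 5, and the empty range `arange_out_size(0, 0, 1)` gives 0. A range whose step points away from `end`, such as `arange_out_size(5, 0, 1)`, still fails because the computed count is negative.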
