Skip to content

Commit 64aab38

Browse files
committed
Update on "[llm] Support different shape of input_pos"
For Hugging Face models, `forward()` takes `tokens` as well as `cache_positions`, which is a list of cache indices. This differs from the .pte files that `export_llama` produces, which take `tokens` and `input_pos`, where `input_pos` is a scalar tensor. This PR adds support inside `text_decoder_runner.cpp` for handling both shapes of `input_pos`/`cache_positions`. To keep the logic generic without relying on extra metadata, it inspects the method meta and the input tensor info to decide whether to feed in `input_pos` or `cache_position`. Differential Revision: [D77203700](https://our.internmc.facebook.com/intern/diff/D77203700/) [ghstack-poisoned]
2 parents 3b95ef7 + 6661e13 commit 64aab38

File tree

4 files changed

+25
-23
lines changed

4 files changed

+25
-23
lines changed

extension/llm/runner/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def define_common_targets():
3434
],
3535
exported_deps = [
3636
":stats",
37-
"//executorch/kernels/portable/cpu/util:arange_util",
37+
"//executorch/kernels/portable/cpu/util:arange_util" + aten_suffix,
3838
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
3939
"//executorch/extension/module:module" + aten_suffix,
4040
"//executorch/extension/tensor:tensor" + aten_suffix,

kernels/portable/cpu/util/arange_util.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,21 @@ namespace torch::executor::native {
1212
#define ET_ARANGE_IMPL(ctx, start, numel, step, out, op_name) \
1313
ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() { \
1414
auto out_data = out.mutable_data_ptr<CTYPE>(); \
15-
for (Tensor::SizesType i = 0; i < numel; ++i) { \
15+
for (executorch::aten::SizesType i = 0; i < numel; ++i) { \
1616
out_data[i] = static_cast<CTYPE>(start + i * step); \
1717
} \
1818
})
1919

20-
Tensor::SizesType
20+
executorch::aten::SizesType
2121
compute_arange_out_size(double start, double end, double step) {
22-
Tensor::SizesType numel =
23-
static_cast<Tensor::SizesType>(std::ceil((end - start) / step));
22+
executorch::aten::SizesType numel =
23+
static_cast<executorch::aten::SizesType>(std::ceil((end - start) / step));
2424

2525
ET_CHECK_MSG(
2626
numel >= 0,
27-
"numel should be non-negative, but got (%d). start (%f), end (%f), step (%f)",
28-
numel,
27+
"numel should be non-negative, but got (%" PRId64
28+
"). start (%f), end (%f), step (%f)",
29+
static_cast<int64_t>(numel),
2930
start,
3031
end,
3132
step);
@@ -39,7 +40,7 @@ void arange_out_impl(
3940
double step,
4041
Tensor& out) {
4142
(void)ctx;
42-
Tensor::SizesType numel = compute_arange_out_size(start, end, step);
43+
executorch::aten::SizesType numel = compute_arange_out_size(start, end, step);
4344
ET_ARANGE_IMPL(ctx, start, numel, step, out, "arange.start_out");
4445
}
4546

kernels/portable/cpu/util/arange_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
namespace torch::executor::native {
1414

15-
Tensor::SizesType
15+
executorch::aten::SizesType
1616
compute_arange_out_size(double start, double end, double step);
1717

18-
inline Tensor::SizesType compute_arange_out_size(double end) {
18+
inline executorch::aten::SizesType compute_arange_out_size(double end) {
1919
return compute_arange_out_size(0.0, end, 1.0);
2020
}
2121

kernels/portable/cpu/util/targets.bzl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -295,19 +295,6 @@ def define_common_targets():
295295
visibility = ["//executorch/kernels/portable/cpu/..."],
296296
)
297297

298-
runtime.cxx_library(
299-
name = "arange_util",
300-
srcs = ["arange_util.cpp"],
301-
exported_headers = ["arange_util.h"],
302-
exported_deps = [
303-
"//executorch/runtime/kernel:kernel_includes",
304-
],
305-
visibility = [
306-
"//executorch/kernels/portable/cpu/...",
307-
"//executorch/extension/llm/...",
308-
],
309-
)
310-
311298
runtime.cxx_library(
312299
name = "broadcast_indexes_range",
313300
exported_headers = ["broadcast_indexes_range.h"],
@@ -343,3 +330,17 @@ def define_common_targets():
343330
"@EXECUTORCH_CLIENTS",
344331
],
345332
)
333+
334+
335+
runtime.cxx_library(
336+
name = "arange_util{}".format(suffix),
337+
srcs = ["arange_util.cpp"],
338+
exported_headers = ["arange_util.h"],
339+
exported_deps = [
340+
"//executorch/runtime/kernel:kernel_includes{}".format(suffix),
341+
],
342+
visibility = [
343+
"//executorch/kernels/portable/cpu/...",
344+
"//executorch/extension/llm/...",
345+
],
346+
)

0 commit comments

Comments
 (0)