Skip to content

Commit 5525c1f

Browse files
committed
Update base for Update on "[llm] Support different shape of input_pos"
For huggingface models, `forward()` is taking `tokens` as well as `cache_positions`, which is a list of cache indices. This is different than the .pte files `export_llama` gives, which are taking `tokens` and `input_pos` where `input_pos` is a scalar tensor. This PR adds support inside `text_decoder_runner.cpp` to handle both shapes of `input_pos`/`cache_positions`. To make the logic more generic without relying on extra metadata, here I'm adding the logic of inspecting method meta and input tensor info, to make a decision if we want to feed in `input_pos` or `cache_position`. Differential Revision: [D77203700](https://our.internmc.facebook.com/intern/diff/D77203700/) [ghstack-poisoned]
1 parent 3be9abf commit 5525c1f

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

kernels/portable/cpu/util/arange_util.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@ namespace torch::executor::native {
1919

2020
Tensor::SizesType
2121
compute_arange_out_size(double start, double end, double step) {
22-
ET_CHECK_MSG(
23-
end > start, "end (%f) must be greater than start (%f)", end, start);
24-
ET_CHECK_MSG(step > 0, "step must be positive, got %f", step);
2522
Tensor::SizesType numel =
2623
static_cast<Tensor::SizesType>(std::ceil((end - start) / step));
24+
25+
ET_CHECK_MSG(
26+
numel >= 0,
27+
"numel should be non-negative, but got (%d). start (%f), end (%f), step (%f)",
28+
numel,
29+
start,
30+
end,
31+
step);
2732
return numel;
2833
}
2934

0 commit comments

Comments
 (0)