Commit 7324285

committed: test

1 parent cb633cf

1 file changed: +14 −2 lines

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 14 additions & 2 deletions
@@ -53,8 +53,20 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
   auto numel = sizes[0];
   std::vector<::executorch::aten::SizesType> sizes_vec = {numel};

-  auto start_pos_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
-      module_, start_pos, tokens->numel()));
+  TensorPtr start_pos_tensor;
+  if (numel > 1) {
+    // If we are here, the model was exported with cache_positions; create a
+    // tensor with the same length as input_ids. Assume the last dimension is
+    // the one with the variable token length, for example [1, S] or [1, 1, S].
+    sizes_vec[sizes_vec.size() - 1] = tokens->numel();
+    start_pos_tensor = empty(sizes_vec, ::executorch::aten::ScalarType::Long);
+    torch::executor::native::arange_out_impl(
+        start_pos, start_pos + tokens->numel(), 1.0, *start_pos_tensor);
+  } else {
+    // Assume the model was exported with input_pos; create a tensor of size 1.
+    start_pos_tensor = from_blob(
+        &start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
+  }

   std::vector<runtime::EValue> inputs;
   auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);
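For readers skimming the diff: the behavioral change is in how the position input is built for one decode step. Below is a minimal, self-contained sketch of that logic under stated assumptions. It is not ExecuTorch code; make_position_input is a hypothetical helper, and a plain std::vector stands in for TensorPtr/empty/from_blob/arange_out_impl, purely to illustrate the two branches (cache_positions vs. input_pos).

// Hypothetical, self-contained sketch of the position-input logic in the diff
// above. It mirrors the new branch in TextDecoderRunner::step(), but uses
// std::vector instead of ExecuTorch tensor APIs so it compiles on its own.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Build the position input for one decode step.
// - cache_positions models: one position per token, [start_pos, start_pos + num_tokens).
// - input_pos models: a single element holding start_pos.
std::vector<int64_t> make_position_input(
    int64_t start_pos,
    int64_t num_tokens,
    bool uses_cache_positions) {
  if (uses_cache_positions) {
    std::vector<int64_t> positions(num_tokens);
    // Equivalent of arange(start_pos, start_pos + num_tokens).
    std::iota(positions.begin(), positions.end(), start_pos);
    return positions;
  }
  // Single "input_pos" value.
  return {start_pos};
}

int main() {
  // Example: decoding 4 tokens starting at KV-cache position 10.
  for (int64_t p : make_position_input(/*start_pos=*/10, /*num_tokens=*/4, /*uses_cache_positions=*/true)) {
    std::cout << p << ' ';  // prints: 10 11 12 13
  }
  std::cout << '\n';
  for (int64_t p : make_position_input(10, 4, /*uses_cache_positions=*/false)) {
    std::cout << p << ' ';  // prints: 10
  }
  std::cout << '\n';
  return 0;
}

In the actual change, the cache_positions branch allocates a Long tensor via empty() and fills it with arange_out_impl over [start_pos, start_pos + tokens->numel()), while the input_pos branch wraps the scalar start_pos with from_blob.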
