diff --git a/extension/llm/runner/text_decoder_runner.cpp b/extension/llm/runner/text_decoder_runner.cpp
index e60a07bc50a..4293b2a08d8 100644
--- a/extension/llm/runner/text_decoder_runner.cpp
+++ b/extension/llm/runner/text_decoder_runner.cpp
@@ -52,22 +52,25 @@ ::executorch::runtime::Result TextDecoderRunner::step(
   auto numel = sizes[0];
   std::vector<::executorch::aten::SizesType> sizes_vec = {numel};
 
-  // Assuming the last dimension is the one with the variable token length,
-  // for example [1, S] or [1, 1, S]
-  sizes_vec[sizes_vec.size() - 1] = numel;
   TensorPtr start_pos_tensor;
   if (numel > 1) {
-    // Assuming model is exported with cache_positions, create a tensor with
-    // the same size as cache_positions
+    // If we are here, model is exported with cache_positions, create a tensor
+    // with the same length as input_ids. Assuming the last dimension is the
+    // one with the variable token length, for example [1, S] or [1, 1, S]
+    sizes_vec[sizes_vec.size() - 1] = tokens->numel();
     start_pos_tensor = empty(sizes_vec, ::executorch::aten::ScalarType::Long);
     torch::executor::native::arange_out_impl(
-        start_pos, start_pos + numel, 1.0, *start_pos_tensor);
+        start_pos, start_pos + tokens->numel(), 1.0, *start_pos_tensor);
   } else {
     // Assuming model is exported with input_pos, create a tensor with size 1
     start_pos_tensor = from_blob(
        &start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
   }
-  ET_LOG(Info, "Start pos tensor numel: %zu", start_pos_tensor->numel());
+  ET_LOG(
+      Info,
+      "Start pos tensor numel: %zu, tokens numel: %zu",
+      start_pos_tensor->numel(),
+      tokens->numel());
   auto outputs_res = module_->forward({tokens, start_pos_tensor});
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
   ET_CHECK_MSG(
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index ce12506a05c..a02cd3d1bf4 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -21,7 +21,7 @@ class ET_EXPERIMENTAL TextPrefiller {
  public:
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
-      bool use_kv_cache_,
+      bool use_kv_cache,
       bool enable_parallel_prefill,
       int64_t max_seq_len = 128);
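
For reviewers, here is a minimal standalone sketch of the behavior the patch changes in `step()`. Before, the cache_positions branch sized the start-position tensor by `numel` (the length of the exported cache_positions input); after, it is sized by `tokens->numel()`, so the arange covers exactly the tokens passed to `forward()`. The sketch uses plain C++ with a hypothetical `make_start_positions` helper rather than the ExecuTorch tensor API, purely to illustrate the values produced by the two branches.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical illustration of the patched logic (not ExecuTorch API):
// - num_tokens > 1: model exported with cache_positions; emit
//   [start_pos, start_pos + num_tokens), mirroring arange_out_impl over
//   tokens->numel() entries.
// - num_tokens == 1: model exported with input_pos; emit a single scalar,
//   mirroring the from_blob path.
std::vector<int64_t> make_start_positions(int64_t start_pos, int64_t num_tokens) {
  std::vector<int64_t> positions;
  if (num_tokens > 1) {
    positions.reserve(static_cast<size_t>(num_tokens));
    for (int64_t p = start_pos; p < start_pos + num_tokens; ++p) {
      positions.push_back(p);
    }
  } else {
    positions.push_back(start_pos);
  }
  return positions;
}

int main() {
  // Prefill of 4 tokens starting at position 7 -> 7 8 9 10
  for (int64_t p : make_start_positions(/*start_pos=*/7, /*num_tokens=*/4)) {
    std::cout << p << ' ';
  }
  std::cout << '\n';

  // Incremental decode of 1 token at position 11 -> 11
  for (int64_t p : make_start_positions(/*start_pos=*/11, /*num_tokens=*/1)) {
    std::cout << p << ' ';
  }
  std::cout << '\n';
  return 0;
}
```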