Skip to content

Commit 5665e9b

Browse files
committed
Update
1 parent 06cb42a commit 5665e9b

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ Result<uint64_t> MultimodalPrefiller::prefill(
103103
std::vector<int64_t> cache_positions;
104104

105105
auto cache_position_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
106-
kTextModelMethod, module_, start_pos, cache_positions, seq_len));
106+
module_, start_pos, cache_positions, seq_len, kTextModelMethod));
107107

108108
auto prefill_result = module_->execute(
109109
kTextModelMethod, {encoder_output, cache_position_tensor});

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
4040

4141
if (use_kv_cache) {
4242
auto start_pos_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
43-
"forward", module_, start_pos, cache_positions, tokens->numel()));
43+
module_, start_pos, cache_positions, tokens->numel()), "forward");
4444

4545
std::vector<runtime::EValue> inputs;
4646
auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);

extension/llm/runner/util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ ET_EXPERIMENTAL size_t inline get_rss_bytes() {
108108
// size 1 because model will populate the cache position tensor underneath), or
109109
// a populated tensor for cache position, for the given start_pos and seq_len.
110110
inline runtime::Result<TensorPtr> populate_start_pos_or_cache_position(
111-
const char* method_name,
112111
Module* module,
113112
int64_t& start_pos,
114113
std::vector<int64_t>& cache_positions_vec,
115-
int seq_len) {
114+
int seq_len,
115+
const char* method_name = "forward") {
116116
// Get expected shape of cache position tensor, which should be the second
117117
// argument
118118
auto method_meta = ET_UNWRAP(module->method_meta(method_name));

0 commit comments

Comments
 (0)