Commit ea4f004

Update cache position population and arg order for multimodal runner (#14225)
For voxtral and phi-3, we construct the cache_position tensor as before; for llava, the model constructs the cache positions underneath, so we pass in a size-1 tensor.
1 parent 5a1c117 commit ea4f004
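
As a rough sketch of the two conventions the commit message describes (illustrative values only, not code from this commit; the actual construction lives in the populate_start_pos_or_cache_position helper added to util.h below):

// voxtral / phi-3: the runner builds and passes the full cache positions,
// e.g. start_pos = 2 and a sequence length of 5 give [2, 3, 4, 5, 6].
std::vector<int64_t> cache_positions = {2, 3, 4, 5, 6};
auto full_positions = ::executorch::extension::from_blob(
    cache_positions.data(), {5}, executorch::aten::ScalarType::Long);

// llava: the exported method expects only a size-1 start_pos tensor and
// populates the cache positions itself underneath.
int64_t start_pos = 2;
auto single_position = ::executorch::extension::from_blob(
    &start_pos, {1}, executorch::aten::ScalarType::Long);

In both cases the tensor is now passed after the embeddings, per the new argument order.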

4 files changed: +60 −43 lines changed

extension/llm/runner/multimodal_decoder_runner.h

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class ET_EXPERIMENTAL MultimodalDecoderRunner
       &start_pos, {1}, executorch::aten::ScalarType::Long);
   // run text model
   auto outputs_res = ET_UNWRAP(
-      module_->execute(kTextModelMethod, {start_pos_tensor, embeddings}));
+      module_->execute(kTextModelMethod, {embeddings, start_pos_tensor}));

   ET_CHECK_MSG(
       outputs_res.size() == 1,
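
The same embeddings-first ordering is applied to the prefill call in multimodal_prefiller.cpp below, so both call sites presumably match the input order of the exported text model method.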

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 10 additions & 12 deletions
@@ -91,24 +91,22 @@ Result<uint64_t> MultimodalPrefiller::prefill(
   }

   // 2. Run decoder model for prefill.
-  // `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
-  // e.g. if start_pos = 2 and encoder_output.size(1) = 5,
-  // cache_position_tensor should be [2, 3, 4, 5, 6].
+
+  // Get expected shape of cache position tensor, which should be the second
+  // argument
+
   int64_t seq_len = encoder_output.toTensor().size(1);
   if (seq_len == 0) {
     ET_LOG(Error, "The encoder returned an empty output.");
     return ::executorch::runtime::Error::InvalidState;
   }
-  std::vector<int64_t> cache_positions(seq_len);
-  for (int64_t i = 0; i < seq_len; ++i) {
-    cache_positions[i] = start_pos + i;
-  }
-  auto cache_position_tensor = ::executorch::extension::from_blob(
-      cache_positions.data(),
-      {static_cast<int>(seq_len)},
-      executorch::aten::ScalarType::Long);
+  std::vector<int64_t> cache_positions;
+
+  auto cache_position_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
+      module_, start_pos, cache_positions, seq_len, kTextModelMethod));
+
   auto prefill_result = module_->execute(
-      kTextModelMethod, {cache_position_tensor, encoder_output});
+      kTextModelMethod, {encoder_output, cache_position_tensor});
   if (prefill_result.error() != ::executorch::runtime::Error::Ok) {
     return prefill_result.error();
   }
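
For a model whose text model method expects a full cache-position tensor (the voxtral/phi-3 case), the helper reproduces the tensor the deleted loop built. A minimal sketch of the call-site contract, assuming a loaded module_ and an enclosing function that returns a Result (required by ET_UNWRAP):

// cache_positions is declared by the caller because from_blob does not copy:
// the returned TensorPtr merely views the vector's buffer, so the storage
// must outlive the subsequent execute() call.
std::vector<int64_t> cache_positions;
int64_t start_pos = 2;
int seq_len = 5; // e.g. encoder_output.toTensor().size(1)
auto cache_position_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
    module_, start_pos, cache_positions, seq_len, kTextModelMethod));
// cache_positions now holds {2, 3, 4, 5, 6}; the tensor wraps that buffer.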

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 4 additions & 30 deletions
@@ -36,37 +36,11 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
   // If only 1 input, we are not using kv cache
   bool use_kv_cache = method_meta.num_inputs() > 1;

+  std::vector<int64_t> cache_positions;
+
   if (use_kv_cache) {
-    // Size of the second argument. This could be either input_pos or
-    // cache_positions
-
-    // Check if we are using cache positions instead of input pos.
-    auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
-    // For input_pos, numel is 1, for cache_positions, numel is max_seq_len
-    auto sizes = second_input_info.sizes();
-    // Assuming 1D tensor
-    ET_CHECK_OR_RETURN_ERROR(
-        sizes.size() == 1,
-        InvalidProgram,
-        "The second input tensor is not 1D tensor. Got dimension (%zu)",
-        sizes.size());
-    auto numel = sizes[0];
-    std::vector<::executorch::aten::SizesType> sizes_vec = {numel};
-
-    TensorPtr start_pos_tensor;
-    if (numel > 1) {
-      // If we are here, model is exported with cache_positions, create a tensor
-      // with the same length as input_ids. Assuming the last dimension is the
-      // one with the variable token length, for example [1, S] or [1, 1, S]
-      sizes_vec[sizes_vec.size() - 1] = tokens->numel();
-      start_pos_tensor = empty(sizes_vec, ::executorch::aten::ScalarType::Long);
-      torch::executor::native::arange_out_impl(
-          start_pos, start_pos + tokens->numel(), 1.0, *start_pos_tensor);
-    } else {
-      // Assuming model is exported with input_pos, create a tensor with size 1
-      start_pos_tensor = from_blob(
-          &start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
-    }
+    auto start_pos_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
+        module_, start_pos, cache_positions, tokens->numel(), "forward"));

     std::vector<runtime::EValue> inputs;
     auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);
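
The cache_positions vector is hoisted above the if (use_kv_cache) block for the same lifetime reason as in the prefiller: the helper fills it and wraps it via from_blob without taking ownership, so the vector must stay alive while start_pos_tensor is used by prepare_decode and the rest of step().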

extension/llm/runner/util.h

Lines changed: 45 additions & 0 deletions
@@ -7,6 +7,9 @@
  */

 #pragma once
+#include <executorch/extension/llm/runner/constants.h>
+#include <executorch/extension/llm/runner/multimodal_prefiller.h>
+#include <executorch/extension/tensor/tensor.h>
 #include <executorch/runtime/platform/compiler.h>
 #include <stdio.h>
 #include <time.h>
@@ -99,6 +102,48 @@ ET_EXPERIMENTAL size_t inline get_rss_bytes() {
   // when this changed.
   return 0;
 }
+
+// Returns the cache position tensor, which can be either a single start_pos
+// (when the method_name [`text_decoder` or `forward`] expects a tensor with
+// size 1 because the model will populate the cache position tensor
+// underneath), or a populated cache position tensor, for the given start_pos
+// and seq_len.
+inline runtime::Result<TensorPtr> populate_start_pos_or_cache_position(
+    Module* module,
+    int64_t& start_pos,
+    std::vector<int64_t>& cache_positions_vec,
+    int seq_len,
+    const char* method_name = "forward") {
+  // Get expected shape of cache position tensor, which should be the second
+  // argument
+  auto method_meta = ET_UNWRAP(module->method_meta(method_name));
+  auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
+  auto second_input_sizes = second_input_info.sizes();
+  auto numel = second_input_sizes[0];
+
+  for (int i = 0; i < second_input_sizes.size(); ++i) {
+    ET_LOG(Error, "second_input_sizes[%d] = %d", i, second_input_sizes[i]);
+  }
+
+  TensorPtr start_pos_tensor;
+  if (numel > 1) {
+    // `cache_position` goes from start_pos to start_pos +
+    // encoder_output.size(1). e.g. if start_pos = 2 and encoder_output.size(1)
+    // = 5, cache_position_tensor should be [2, 3, 4, 5, 6].
+    cache_positions_vec.resize(seq_len);
+    for (int64_t i = 0; i < seq_len; ++i) {
+      cache_positions_vec[i] = start_pos + i;
+    }
+    return ::executorch::extension::from_blob(
+        cache_positions_vec.data(),
+        {static_cast<int>(seq_len)},
+        executorch::aten::ScalarType::Long);
+  } else {
+    // Cache position is size 1.
+    return ::executorch::extension::from_blob(
+        &start_pos, {1}, executorch::aten::ScalarType::Long);
+  }
+}
+
 } // namespace llm
 } // namespace extension
 } // namespace executorch
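
For the size-1 branch (the llava case), the helper skips the vector entirely and returns a view over start_pos itself. A usage sketch with hypothetical values, again assuming an enclosing function that returns a Result:

int64_t start_pos = 7;
std::vector<int64_t> unused; // the numel == 1 branch never touches it
auto pos = ET_UNWRAP(populate_start_pos_or_cache_position(
    module, start_pos, unused, /*seq_len=*/1, "forward"));
// pos has shape {1} and aliases start_pos directly, which is why the helper
// takes start_pos by non-const reference; the caller's variable must outlive
// any use of pos.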
