Skip to content

Commit 9bf819c

Browse files
committed
Update cache position size for llava
1 parent a4b7de0 commit 9bf819c

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,11 @@ Result<uint64_t> MultimodalPrefiller::prefill(
9494
// `cache_position` goes from start_pos to start_pos + encoder_output.size(1) - 1 (inclusive).
9595
// e.g. if start_pos = 2 and encoder_output.size(1) = 5,
9696
// cache_position_tensor should be [2, 3, 4, 5, 6].
97+
auto method_meta = ET_UNWRAP(module_->method_meta(kTextModelMethod));
98+
auto first_input_info = ET_UNWRAP(method_meta.input_tensor_meta(0));
99+
auto sizes = first_input_info.sizes();
100+
auto numel = sizes[0];
101+
97102
int64_t seq_len = encoder_output.toTensor().size(1);
98103
if (seq_len == 0) {
99104
ET_LOG(Error, "The encoder returned an empty output.");
@@ -103,10 +108,13 @@ Result<uint64_t> MultimodalPrefiller::prefill(
103108
for (int64_t i = 0; i < seq_len; ++i) {
104109
cache_positions[i] = start_pos + i;
105110
}
106-
auto cache_position_tensor = ::executorch::extension::from_blob(
107-
cache_positions.data(),
108-
{static_cast<int>(seq_len)},
109-
executorch::aten::ScalarType::Long);
111+
auto cache_position_tensor = (numel > 1)
112+
? ::executorch::extension::from_blob(
113+
cache_positions.data(),
114+
{static_cast<int>(seq_len)},
115+
executorch::aten::ScalarType::Long)
116+
: ::executorch::extension::from_blob(
117+
&start_pos, {1}, executorch::aten::ScalarType::Long);
110118
auto prefill_result = module_->execute(
111119
kTextModelMethod, {cache_position_tensor, encoder_output});
112120
if (prefill_result.error() != ::executorch::runtime::Error::Ok) {

0 commit comments

Comments
 (0)