Skip to content

Commit ad1116e

Browse files
committed
Address comments
1 parent 9bf819c commit ad1116e

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,30 +91,38 @@ Result<uint64_t> MultimodalPrefiller::prefill(
9191
}
9292

9393
// 2. Run decoder model for prefill.
94-
// `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
95-
// e.g. if start_pos = 2 and encoder_output.size(1) = 5,
96-
// cache_position_tensor should be [2, 3, 4, 5, 6].
94+
95+
// Get expected shape of cache position tensor, which should be the second
96+
// argument
9797
auto method_meta = ET_UNWRAP(module_->method_meta(kTextModelMethod));
98-
auto first_input_info = ET_UNWRAP(method_meta.input_tensor_meta(0));
99-
auto sizes = first_input_info.sizes();
100-
auto numel = sizes[0];
98+
auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
99+
auto second_input_sizes = second_input_info.sizes();
100+
auto numel = second_input_sizes[0];
101101

102102
int64_t seq_len = encoder_output.toTensor().size(1);
103103
if (seq_len == 0) {
104104
ET_LOG(Error, "The encoder returned an empty output.");
105105
return ::executorch::runtime::Error::InvalidState;
106106
}
107-
std::vector<int64_t> cache_positions(seq_len);
108-
for (int64_t i = 0; i < seq_len; ++i) {
109-
cache_positions[i] = start_pos + i;
107+
108+
executorch::extension::TensorPtr cache_position_tensor;
109+
if (numel > 1) {
110+
// `cache_position` goes from start_pos to start_pos +
111+
// encoder_output.size(1). e.g. if start_pos = 2 and encoder_output.size(1)
112+
// = 5, cache_position_tensor should be [2, 3, 4, 5, 6].
113+
std::vector<int64_t> cache_positions(seq_len);
114+
for (int64_t i = 0; i < seq_len; ++i) {
115+
cache_positions[i] = start_pos + i;
116+
}
117+
cache_position_tensor = ::executorch::extension::from_blob(
118+
cache_positions.data(),
119+
{static_cast<int>(seq_len)},
120+
executorch::aten::ScalarType::Long);
121+
} else {
122+
// Cache position is size 1.
123+
cache_position_tensor = ::executorch::extension::from_blob(
124+
&start_pos, {1}, executorch::aten::ScalarType::Long);
110125
}
111-
auto cache_position_tensor = (numel > 1)
112-
? ::executorch::extension::from_blob(
113-
cache_positions.data(),
114-
{static_cast<int>(seq_len)},
115-
executorch::aten::ScalarType::Long)
116-
: ::executorch::extension::from_blob(
117-
&start_pos, {1}, executorch::aten::ScalarType::Long);
118126
auto prefill_result = module_->execute(
119127
kTextModelMethod, {cache_position_tensor, encoder_output});
120128
if (prefill_result.error() != ::executorch::runtime::Error::Ok) {

0 commit comments

Comments
 (0)