Skip to content

Commit 98b1052

Browse files
authored
StaticAttentionIOManager: Fix out of bound errors on precomuted RoPE frequencies
Differential Revision: D83361153 Pull Request resolved: #14630
1 parent 411578a commit 98b1052

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ class StaticAttentionIOManager {
434434
std::vector<size_t> k_cache_output_indices;
435435
std::vector<size_t> v_cache_input_indices;
436436
std::vector<size_t> v_cache_output_indices;
437+
size_t max_context_len{};
437438
RopeT* rope_freqs_cos;
438439
RopeT* rope_freqs_sin;
439440
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
@@ -604,6 +605,10 @@ class StaticAttentionIOManager {
604605
size_t batch_len = 0;
605606
for (size_t i = 0; i < tokens.size(); i += input_len) {
606607
batch_len = std::min(input_len, tokens.size() - i);
608+
if (input_pos_ + batch_len > config_.max_context_len) {
609+
ET_LOG(Error, "Maximum context size reached, stopping prefill.");
610+
return input_len - 1;
611+
}
607612
std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
608613
prepare(method);
609614
ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
@@ -646,6 +651,10 @@ class StaticAttentionIOManager {
646651

647652
while (true) {
648653
input_buffer[0] = prev_tok;
654+
if (input_pos_ + 1 > config_.max_context_len) {
655+
ET_LOG(Error, "Maximum context size reached, stopping decode.");
656+
break;
657+
}
649658
prepare(method);
650659
ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
651660
update(
@@ -730,6 +739,11 @@ class StaticAttentionIOManager {
730739
}
731740

732741
// Setup input pointers and RoPE frequencies.
742+
if (input_pos_ + ngram_size > config_.max_context_len) {
743+
ET_LOG(
744+
Error, "Maximum context size reached, stopping lookahead decode.");
745+
break;
746+
}
733747
prepare(
734748
method,
735749
executorch::runtime::Span(pos_offsets.data(), pos_offsets.size()));

0 commit comments

Comments
 (0)