@@ -434,6 +434,7 @@ class StaticAttentionIOManager {
434434 std::vector<size_t > k_cache_output_indices;
435435 std::vector<size_t > v_cache_input_indices;
436436 std::vector<size_t > v_cache_output_indices;
437+ size_t max_context_len{};
437438 RopeT* rope_freqs_cos;
438439 RopeT* rope_freqs_sin;
439440 StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
@@ -604,6 +605,10 @@ class StaticAttentionIOManager {
604605 size_t batch_len = 0 ;
605606 for (size_t i = 0 ; i < tokens.size (); i += input_len) {
606607 batch_len = std::min (input_len, tokens.size () - i);
608+ if (input_pos_ + batch_len > config_.max_context_len ) {
609+ ET_LOG (Error, " Maximum context size reached, stopping prefill." );
610+ return input_len - 1 ;
611+ }
607612 std::copy (&tokens[i], &tokens[i + batch_len], input_buffer.begin ());
608613 prepare (method);
609614 ET_CHECK (method.execute () == executorch::runtime::Error::Ok);
@@ -646,6 +651,10 @@ class StaticAttentionIOManager {
646651
647652 while (true ) {
648653 input_buffer[0 ] = prev_tok;
654+ if (input_pos_ + 1 > config_.max_context_len ) {
655+ ET_LOG (Error, " Maximum context size reached, stopping decode." );
656+ break ;
657+ }
649658 prepare (method);
650659 ET_CHECK (method.execute () == executorch::runtime::Error::Ok);
651660 update (
@@ -730,6 +739,11 @@ class StaticAttentionIOManager {
730739 }
731740
732741 // Setup input pointers and RoPE frequencies.
742+ if (input_pos_ + ngram_size > config_.max_context_len ) {
743+ ET_LOG (
744+ Error, " Maximum context size reached, stopping lookahead decode." );
745+ break ;
746+ }
733747 prepare (
734748 method,
735749 executorch::runtime::Span (pos_offsets.data (), pos_offsets.size ()));
0 commit comments