@@ -434,6 +434,7 @@ class StaticAttentionIOManager {
434
434
std::vector<size_t > k_cache_output_indices;
435
435
std::vector<size_t > v_cache_input_indices;
436
436
std::vector<size_t > v_cache_output_indices;
437
+ size_t max_context_len{};
437
438
RopeT* rope_freqs_cos;
438
439
RopeT* rope_freqs_sin;
439
440
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
@@ -604,6 +605,10 @@ class StaticAttentionIOManager {
604
605
size_t batch_len = 0 ;
605
606
for (size_t i = 0 ; i < tokens.size (); i += input_len) {
606
607
batch_len = std::min (input_len, tokens.size () - i);
608
+ if (input_pos_ + batch_len > config_.max_context_len ) {
609
+ ET_LOG (Error, " Maximum context size reached, stopping prefill." );
610
+ return input_len - 1 ;
611
+ }
607
612
std::copy (&tokens[i], &tokens[i + batch_len], input_buffer.begin ());
608
613
prepare (method);
609
614
ET_CHECK (method.execute () == executorch::runtime::Error::Ok);
@@ -646,6 +651,10 @@ class StaticAttentionIOManager {
646
651
647
652
while (true ) {
648
653
input_buffer[0 ] = prev_tok;
654
+ if (input_pos_ + 1 > config_.max_context_len ) {
655
+ ET_LOG (Error, " Maximum context size reached, stopping decode." );
656
+ break ;
657
+ }
649
658
prepare (method);
650
659
ET_CHECK (method.execute () == executorch::runtime::Error::Ok);
651
660
update (
@@ -730,6 +739,11 @@ class StaticAttentionIOManager {
730
739
}
731
740
732
741
// Setup input pointers and RoPE frequencies.
742
+ if (input_pos_ + ngram_size > config_.max_context_len ) {
743
+ ET_LOG (
744
+ Error, " Maximum context size reached, stopping lookahead decode." );
745
+ break ;
746
+ }
733
747
prepare (
734
748
method,
735
749
executorch::runtime::Span (pos_offsets.data (), pos_offsets.size ()));
0 commit comments