@@ -193,21 +193,15 @@ struct server_slot {
 
     llama_token sampled;
 
-    int32_t ga_i = 0;   // group-attention state
-    int32_t ga_n = 1;   // group-attention factor
-    int32_t ga_w = 512; // group-attention width
-
-    int32_t n_past_se = 0; // self-extend
-
     // stats
-    size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
 
     double t_prompt_processing; // ms
-    double t_token_generation; // ms
+    double t_token_generation; // ms
 
     std::function<void(int)> callback_on_release;
 
@@ -225,8 +219,6 @@ struct server_slot {
         n_sent_text = 0;
         n_sent_token_probs = 0;
         cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-        ga_i = 0;
-        n_past_se = 0;
 
         generated_token_probs.clear();
     }
@@ -705,22 +697,6 @@ struct server_context {
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            const int ga_n = params.grp_attn_n;
-            const int ga_w = params.grp_attn_w;
-
-            if (ga_n != 1) {
-                GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
-                GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
-                // GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
-                // GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
-                SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
-            }
-
-            slot.ga_i = 0;
-            slot.ga_n = ga_n;
-            slot.ga_w = ga_w;
-
             slot.sparams = params.sparams;
 
             slot.callback_on_release = [this](int) {
@@ -906,19 +882,14 @@ struct server_context {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar = json_schema_to_grammar(schema);
+                auto schema = json_value(data, "json_schema", json::object());
+                slot.sparams.grammar = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         } else {
-            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
-        }
-
-        if (slot.params.cache_prompt && slot.ga_n != 1) {
-            slot.params.cache_prompt = false;
-            SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
+            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         }
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
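For context, the request-parsing branch above gives `json_schema` precedence: when the body carries a `json_schema` field and no explicit `grammar`, the schema is converted to a GBNF grammar, otherwise the `grammar` field (or the default) is used. A minimal standalone sketch of that precedence rule, using nlohmann::json and a placeholder `schema_to_grammar` in place of the server's `json_schema_to_grammar` (both helper names here are illustrative, not the server's code):

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

// Placeholder: the real converter is json_schema_to_grammar() in llama.cpp's common code.
static std::string schema_to_grammar(const json & schema) {
    return "root ::= /* grammar derived from schema: */ " + schema.dump();
}

// Mirrors the precedence in the hunk above: json_schema wins only when no explicit grammar is given.
static std::string pick_grammar(const json & data, const std::string & default_grammar) {
    if (data.contains("json_schema") && !data.contains("grammar")) {
        return schema_to_grammar(data.at("json_schema"));
    }
    return data.contains("grammar") ? data.at("grammar").get<std::string>() : default_grammar;
}

int main() {
    const json req = { { "json_schema", { { "type", "object" } } } };
    std::cout << pick_grammar(req, /*default_grammar=*/"") << "\n";
}
```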
@@ -1131,12 +1102,13 @@ struct server_context {
         }
 
         // if context shift is disabled, we stop when it reaches the context limit
-        if (slot.n_decoded >= slot.n_ctx) {
+        if (slot.n_past >= slot.n_ctx) {
             slot.truncated = true;
             slot.stopped_limit = true;
             slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
         }
 
         if (llama_token_is_eog(model, result.tok)) {
@@ -1148,13 +1120,13 @@ struct server_context {
 
         const auto n_ctx_train = llama_n_ctx_train(model);
 
-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
             slot.truncated = true;
             slot.stopped_limit = true;
             slot.has_next_token = false; // stop prediction
 
             SLT_WRN(slot,
-                    "n_predict (%d) is not set and self-context extend is disabled. "
+                    "n_predict (%d) is set for infinite generation. "
                     "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
                     slot.params.n_predict, n_ctx_train);
         }
@@ -1826,38 +1798,36 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.ga_n == 1) {
-                if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
-                    if (!params.ctx_shift) {
-                        // this check is redundant (for good)
-                        // we should never get here, because generation should already stopped in process_token()
-                        slot.release();
-                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
-                        continue;
-                    }
-
-                    // Shift context
-                    const int n_keep = slot.params.n_keep + add_bos_token;
-                    const int n_left = slot.n_past - n_keep;
-                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+                if (!params.ctx_shift) {
+                    // this check is redundant (for good)
+                    // we should never get here, because generation should already stopped in process_token()
+                    slot.release();
+                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                    continue;
+                }
 
-                    SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
+                // Shift context
+                const int n_keep = slot.params.n_keep + add_bos_token;
+                const int n_left = slot.n_past - n_keep;
+                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
-                    llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep, n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                    if (slot.params.cache_prompt) {
-                        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                            slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                        }
+                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep, n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
 
-                        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                if (slot.params.cache_prompt) {
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                     }
 
-                    slot.n_past -= n_discard;
-
-                    slot.truncated = true;
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
                 }
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
             }
         }
 
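The reindented block above is the context-shift path: once a slot's `n_past` reaches the slot context size, the first `n_discard` tokens after `n_keep` are dropped from the KV cache (`llama_kv_cache_seq_rm`) and the remaining entries are slid back by `n_discard` positions (`llama_kv_cache_seq_add`), with `cache_tokens` compacted to match. A minimal sketch of the same bookkeeping applied to a plain token vector; the KV-cache calls are left as comments and the function name is illustrative, not part of the server:

```cpp
#include <cstdio>
#include <vector>

// Sketch of the context-shift arithmetic from the hunk above, applied to a
// plain token buffer standing in for the slot's cache_tokens.
static void context_shift(std::vector<int> & tokens, int & n_past, int n_keep, int n_discard_param) {
    const int n_left    = n_past - n_keep;
    const int n_discard = n_discard_param ? n_discard_param : (n_left / 2);

    // KV-cache equivalent (not runnable here):
    //   llama_kv_cache_seq_rm (ctx, seq, n_keep,             n_keep + n_discard);
    //   llama_kv_cache_seq_add(ctx, seq, n_keep + n_discard, n_past, -n_discard);

    // compact the mirrored token buffer the same way the diff compacts cache_tokens
    for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
        tokens[i - n_discard] = tokens[i];
    }
    tokens.resize(tokens.size() - n_discard);

    n_past -= n_discard;
}

int main() {
    std::vector<int> tokens(32);
    for (int i = 0; i < 32; i++) tokens[i] = i;

    int n_past = 32;
    context_shift(tokens, n_past, /*n_keep=*/4, /*n_discard_param=*/0);

    printf("n_past = %d, tokens kept = %zu\n", n_past, tokens.size()); // n_past = 18, tokens kept = 18
}
```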
@@ -1872,9 +1842,7 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-            common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
 
             slot.n_past += 1;
 
@@ -1993,6 +1961,8 @@ struct server_context {
                 } else {
                     if (!params.ctx_shift) {
                         // if context shift is disabled, we make sure prompt size is smaller than KV size
+                        // TODO: there should be a separate parameter that control prompt truncation
+                        //       context shift should be applied only during the generation phase
                         if (slot.n_prompt_tokens >= slot.n_ctx) {
                             slot.release();
                             send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
@@ -2005,7 +1975,7 @@ struct server_context {
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
                     // if input prompt is too big, truncate it (if group attention self-extend is disabled)
-                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                    if (slot.n_prompt_tokens >= slot.n_ctx) {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
 
                         const int n_block_size = n_left / 2;
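For reference, the truncation branch whose condition is simplified here keeps the first `n_keep` tokens and the tail of the prompt, dropping whole blocks of `n_block_size = n_left / 2` tokens from the middle so the result fits in `n_ctx`. A rough standalone sketch of that block arithmetic, an approximation of the unchanged code around this hunk rather than a verbatim copy (it assumes the prompt is longer than `n_ctx`):

```cpp
#include <cstdio>
#include <vector>

// Rough sketch of llama.cpp-server-style prompt truncation: keep the first n_keep
// tokens, drop whole blocks of n_block_size tokens from the middle, keep the tail.
static std::vector<int> truncate_prompt(const std::vector<int> & prompt, int n_ctx, int n_keep) {
    const int n_prompt      = (int) prompt.size();
    const int n_left        = n_ctx - n_keep;
    const int n_block_size  = n_left / 2;
    const int erased_blocks = (n_prompt - n_keep - n_block_size) / n_block_size;

    std::vector<int> out(prompt.begin(), prompt.begin() + n_keep);
    out.insert(out.end(), prompt.begin() + n_keep + erased_blocks * n_block_size, prompt.end());
    return out;
}

int main() {
    std::vector<int> prompt(1000);
    for (int i = 0; i < 1000; i++) prompt[i] = i;

    const auto out = truncate_prompt(prompt, /*n_ctx=*/512, /*n_keep=*/8);
    printf("truncated %zu -> %zu tokens\n", prompt.size(), out.size());
}
```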
@@ -2032,12 +2002,7 @@ struct server_context {
 
                     common_sampler_reset(slot.smpl);
 
-                    if (!slot.params.cache_prompt) {
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
-                    } else {
-                        GGML_ASSERT(slot.ga_n == 1);
-
+                    if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
@@ -2053,9 +2018,6 @@ struct server_context {
                         SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
 
                         slot.n_past--;
-                        if (slot.ga_i > 0) {
-                            slot.n_past_se--;
-                        }
                     }
 
                     slot.n_prompt_tokens_processed = 0;
@@ -2081,52 +2043,31 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    int p0 = slot.n_past;
-
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
                         llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
 
-                        p0 = 0;
-
                         // there is no common part left
                         slot.n_past = 0;
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
 
                         common_sampler_reset(slot.smpl);
                     }
 
+                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
                     // remove the non-common part from the cache
                     slot.cache_tokens.resize(slot.n_past);
 
-                    SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
-                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                    int32_t ga_i = slot.ga_i;
-                    int32_t ga_n = slot.ga_n;
-                    int32_t ga_w = slot.ga_w;
-
                     // add prompt tokens for processing in the current batch
-                    // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
-                    for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
-                        if (slot.ga_n != 1) {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                        }
-
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                         }
 
                         slot.n_prompt_tokens_processed++;
-                        slot_npast++;
+                        slot.n_past++;
                     }
 
                     SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
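With the group-attention assert gone, prompt caching reduces to the plain prefix reuse visible above: `common_part` returns the length of the longest common prefix between the cached tokens and the new prompt, `n_past` starts there, and only the non-common suffix is fed through `common_batch_add`. A small sketch of that prefix computation, using a stand-in for `common_part` rather than the server's implementation:

```cpp
#include <cstdio>
#include <vector>
#include <algorithm>

// Stand-in for common_part(): length of the longest common prefix of two token sequences.
static size_t common_prefix(const std::vector<int> & a, const std::vector<int> & b) {
    const size_t n = std::min(a.size(), b.size());
    size_t i = 0;
    while (i < n && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    const std::vector<int> cached = { 1, 2, 3, 4, 5 };      // tokens already in the KV cache
    const std::vector<int> prompt = { 1, 2, 3, 9, 10, 11 }; // new request

    const size_t n_past = common_prefix(cached, prompt);
    printf("reusing %zu cached tokens, evaluating %zu new ones\n", n_past, prompt.size() - n_past);
}
```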
@@ -2167,34 +2108,6 @@ struct server_context {
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
-            for (auto & slot : slots) {
-                if (slot.ga_n != 1) {
-                    // context extension via Self-Extend
-                    // TODO: simplify and/or abstract this
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
-                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                        SLT_DBG(slot, "div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
-                        slot.n_past_se -= bd;
-
-                        slot.ga_i += slot.ga_w / slot.ga_n;
-
-                        SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                    }
-
-                    slot.n_past_se += n_tokens;
-                }
-            }
-
             llama_batch batch_view = {
                 n_tokens,
                 batch.token + i,
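The block removed above implemented context extension via Self-Extend (grouped attention): whenever a slot's effective position `n_past_se` reached the end of the current group window `[ga_i, ga_i + ga_w)`, the KV-cache positions were shifted (`llama_kv_cache_seq_add`) and divided by `ga_n` (`llama_kv_cache_seq_div`), so each full window of `ga_w` tokens ends up occupying only `ga_w / ga_n` cache positions. A minimal illustrative sketch of the resulting position bookkeeping, derived from the `bd` / `ga_i` updates in the deleted loop; it gives the position a token receives when it is first added (earlier tokens' positions shrink again as later windows are compressed), and the helper name is made up for illustration:

```cpp
#include <cstdio>

// Illustrative only: effective KV-cache position at which token index `i` is added
// under self-extend with group factor ga_n and window ga_w. Every already-completed
// window of ga_w tokens has been compressed to ga_w / ga_n positions; the current,
// not-yet-compressed window still advances one position per token.
static int grouped_position(int i, int ga_n, int ga_w) {
    const int full = i / ga_w; // windows that have already been compressed
    const int rem  = i % ga_w; // tokens in the current, not-yet-compressed window
    return full * (ga_w / ga_n) + rem;
}

int main() {
    const int ga_n = 4, ga_w = 512;
    for (int i : { 0, 511, 512, 2047, 4096 }) {
        printf("token %5d -> cache position %5d\n", i, grouped_position(i, ga_n, ga_w));
    }
}
```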