@@ -193,21 +193,15 @@ struct server_slot {
 
     llama_token sampled;
 
-    int32_t ga_i = 0;   // group-attention state
-    int32_t ga_n = 1;   // group-attention factor
-    int32_t ga_w = 512; // group-attention width
-
-    int32_t n_past_se = 0; // self-extend
-
     // stats
-    size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text        = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
 
     double t_prompt_processing; // ms
-    double t_token_generation; // ms
+    double t_token_generation;  // ms
 
     std::function<void(int)> callback_on_release;
 
@@ -225,8 +219,6 @@ struct server_slot {
         n_sent_text = 0;
         n_sent_token_probs = 0;
         cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-        ga_i = 0;
-        n_past_se = 0;
 
         generated_token_probs.clear();
     }
@@ -705,22 +697,6 @@ struct server_context {
 
         SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-        const int ga_n = params.grp_attn_n;
-        const int ga_w = params.grp_attn_w;
-
-        if (ga_n != 1) {
-            GGML_ASSERT(ga_n > 0 && "ga_n must be positive");                    // NOLINT
-            GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n");  // NOLINT
-            //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w");       // NOLINT
-            //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
-            SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
-        }
-
-        slot.ga_i = 0;
-        slot.ga_n = ga_n;
-        slot.ga_w = ga_w;
-
         slot.sparams = params.sparams;
 
         slot.callback_on_release = [this](int) {
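
For context, the removed block only validated the two group-attention knobs before copying them into the slot: the factor must be positive and the width a multiple of the factor. A minimal standalone sketch of that check (names mirror the source; returning a bool instead of using GGML_ASSERT is illustrative):

```cpp
// Standalone sketch of the removed validation: group attention is disabled
// when the factor is 1; otherwise the factor must be positive and the width
// a multiple of the factor (ga_w % ga_n == 0).
#include <cstdio>

static bool grp_attn_params_valid(int ga_n, int ga_w) {
    if (ga_n == 1) {
        return true; // factor 1 means group attention is off
    }
    return ga_n > 0 && ga_w % ga_n == 0; // short-circuit avoids % by zero
}

int main() {
    std::printf("%d\n", grp_attn_params_valid(4, 512)); // 1: 512 is a multiple of 4
    std::printf("%d\n", grp_attn_params_valid(3, 512)); // 0: 512 is not a multiple of 3
}
```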
@@ -906,19 +882,14 @@ struct server_context {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar = json_schema_to_grammar(schema);
+                auto schema          = json_value(data, "json_schema", json::object());
+                slot.sparams.grammar = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         } else {
-            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
-        }
-
-        if (slot.params.cache_prompt && slot.ga_n != 1) {
-            slot.params.cache_prompt = false;
-            SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
+            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         }
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
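
The grammar-selection precedence itself is unchanged by this commit: an explicit "grammar" always wins, otherwise a "json_schema" field is converted to GBNF. A sketch of that precedence, with the converter stubbed (in the server it is json_schema_to_grammar() from common/json-schema-to-grammar.h):

```cpp
// Sketch of the grammar-selection precedence above. The schema converter is
// stubbed so the sketch stands alone; the real conversion lives in common/.
#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::ordered_json;

static std::string schema_to_grammar_stub(const json & /*schema*/) {
    return "root ::= ..."; // placeholder GBNF
}

static std::string pick_grammar(const json & data, const std::string & default_grammar) {
    // "json_schema" is used only when no explicit "grammar" is given
    if (data.contains("json_schema") && !data.contains("grammar")) {
        return schema_to_grammar_stub(data.at("json_schema"));
    }
    return data.value("grammar", default_grammar);
}
```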
@@ -1148,13 +1119,13 @@ struct server_context {
 
         const auto n_ctx_train = llama_n_ctx_train(model);
 
-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
            slot.truncated      = true;
            slot.stopped_limit  = true;
            slot.has_next_token = false; // stop prediction
 
            SLT_WRN(slot,
-                   "n_predict (%d) is not set and self-context extend is disabled. "
+                   "n_predict (%d) is set for infinite generation. "
                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
                    slot.params.n_predict, n_ctx_train);
        }
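
A small numeric illustration of the guard above: with no n_predict limit, generation stops once prompt plus decoded tokens reach the training context, so a model that never emits EOS cannot loop forever. Values are illustrative:

```cpp
// Sketch of the stop condition with illustrative numbers.
#include <cstdio>

int main() {
    const int n_ctx_train     = 4096; // llama_n_ctx_train(model)
    const int n_prompt_tokens = 3000;
    int       n_decoded       = 0;

    bool has_next_token = true;
    while (has_next_token) {
        n_decoded++;
        if (n_prompt_tokens + n_decoded >= n_ctx_train) {
            has_next_token = false; // stop prediction
        }
    }
    std::printf("stopped after %d generated tokens\n", n_decoded); // 1096
}
```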
@@ -1826,38 +1797,36 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.ga_n == 1) {
-                if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
-                    if (!params.ctx_shift) {
-                        // this check is redundant (for good)
-                        // we should never get here, because generation should already stopped in process_token()
-                        slot.release();
-                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
-                        continue;
-                    }
-
-                    // Shift context
-                    const int n_keep    = slot.params.n_keep + add_bos_token;
-                    const int n_left    = slot.n_past - n_keep;
-                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+            if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
+                if (!params.ctx_shift) {
+                    // this check is redundant (for good)
+                    // we should never get here, because generation should already stopped in process_token()
+                    slot.release();
+                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                    continue;
+                }
 
-                    SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
+                // Shift context
+                const int n_keep    = slot.params.n_keep + add_bos_token;
+                const int n_left    = slot.n_past - n_keep;
+                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
-                    llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                    if (slot.params.cache_prompt) {
-                        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                            slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                        }
+                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
 
-                        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                if (slot.params.cache_prompt) {
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                     }
 
-                    slot.n_past -= n_discard;
-
-                    slot.truncated = true;
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
                 }
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
             }
         }
 
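The shift itself is unchanged; only the ga_n guard and one indentation level are gone. A pure-index sketch of the arithmetic, with a plain vector standing in for cache_tokens (the two llama_kv_cache_seq_* calls apply the same remapping to the KV-cache positions):

```cpp
// Pure-index sketch of the context shift: drop n_discard tokens after the
// first n_keep, then slide the tail back over the gap.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> cache_tokens(32);
    for (int i = 0; i < (int) cache_tokens.size(); i++) cache_tokens[i] = i;

    int n_past = (int) cache_tokens.size();

    const int n_keep    = 4;
    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2; // default when params.n_discard == 0

    // shift the tail over the discarded span, as the cache_prompt branch does
    for (size_t i = n_keep + n_discard; i < cache_tokens.size(); i++) {
        cache_tokens[i - n_discard] = cache_tokens[i];
    }
    cache_tokens.resize(cache_tokens.size() - n_discard);

    n_past -= n_discard;

    // keeps tokens 0..3 and 18..31; n_keep = 4, n_discard = 14, n_past = 18
    std::printf("n_keep = %d, n_discard = %d, n_past = %d\n", n_keep, n_discard, n_past);
}
```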
@@ -1872,9 +1841,7 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-            common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
 
             slot.n_past += 1;
 
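With self-extend gone, the sampled token is simply appended at the slot's true position. For reference, a simplified mirror of what common_batch_add does to the batch, based on the helper in common/common.cpp (abbreviated; the real helper also asserts that the batch has room):

```cpp
// Simplified mirror of common_batch_add: append one token at an explicit
// position, tagged with its sequence ids and a flag for whether logits are
// needed for it. Field handling abbreviated.
#include <vector>
#include "llama.h"

static void batch_add_sketch(llama_batch & batch, llama_token id, llama_pos pos,
                             const std::vector<llama_seq_id> & seq_ids, bool logits) {
    batch.token   [batch.n_tokens] = id;
    batch.pos     [batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;

    batch.n_tokens++;
}
```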
@@ -2005,7 +1972,7 @@ struct server_context {
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
                     // if input prompt is too big, truncate it (if group attention self-extend is disabled)
-                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                    if (slot.n_prompt_tokens >= slot.n_ctx) {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
 
                         const int n_block_size = n_left / 2;
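
The hunk ends before the truncation body, but the visible setup implies block-wise truncation: keep the first n_keep tokens and drop whole n_block_size chunks after them. A hedged sketch under that assumption (the erased_blocks formula is illustrative, not taken from this diff):

```cpp
// Hedged sketch of block-based prompt truncation: keep the n_keep prefix,
// erase whole n_block_size chunks after it, keep the tail. The erased_blocks
// formula is an assumption; the rest of the branch is outside this hunk.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> prompt_tokens(600);
    for (int i = 0; i < 600; i++) prompt_tokens[i] = i;

    const int n_ctx  = 512;
    const int n_keep = 4;

    const int n_left        = n_ctx - n_keep; // 508
    const int n_block_size  = n_left / 2;     // 254
    const int erased_blocks =
        ((int) prompt_tokens.size() - n_keep - n_block_size) / n_block_size; // 1

    std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
    new_tokens.insert(new_tokens.end(),
                      prompt_tokens.begin() + n_keep + erased_blocks * n_block_size,
                      prompt_tokens.end());

    // 600 -> 346 tokens, which now fits in n_ctx = 512
    std::printf("truncated %zu -> %zu tokens\n", prompt_tokens.size(), new_tokens.size());
}
```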
@@ -2032,12 +1999,7 @@ struct server_context {
 
                     common_sampler_reset(slot.smpl);
 
-                    if (!slot.params.cache_prompt) {
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
-                    } else {
-                        GGML_ASSERT(slot.ga_n == 1);
-
+                    if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
@@ -2053,9 +2015,6 @@ struct server_context {
                         SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
 
                         slot.n_past--;
-                        if (slot.ga_i > 0) {
-                            slot.n_past_se--;
-                        }
                     }
 
                     slot.n_prompt_tokens_processed = 0;
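
common_part is used here as the longest shared token prefix between the cached tokens and the new prompt; that length becomes the new n_past. A minimal sketch of that behavior (the helper's name is taken from the call site; its actual implementation lives elsewhere in the server code):

```cpp
// Minimal sketch of common_part: length of the longest shared prefix.
#include <cstdio>
#include <vector>

static size_t common_part(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    const std::vector<int> cache_tokens  = {1, 15043, 29892, 3186};
    const std::vector<int> prompt_tokens = {1, 15043, 29892, 590, 1024};

    // the first three tokens match, so only the rest needs to be evaluated
    std::printf("n_past = %zu\n", common_part(cache_tokens, prompt_tokens)); // 3
}
```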
@@ -2081,52 +2040,31 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    int p0 = slot.n_past;
-
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
                         llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
 
-                        p0 = 0;
-
                         // there is no common part left
                         slot.n_past = 0;
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
 
                         common_sampler_reset(slot.smpl);
                     }
 
+                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
                     // remove the non-common part from the cache
                     slot.cache_tokens.resize(slot.n_past);
 
-                    SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
-                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                    int32_t ga_i = slot.ga_i;
-                    int32_t ga_n = slot.ga_n;
-                    int32_t ga_w = slot.ga_w;
-
                     // add prompt tokens for processing in the current batch
-                    // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
-                    for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
-                        if (slot.ga_n != 1) {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                        }
-
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                         }
 
                         slot.n_prompt_tokens_processed++;
-                        slot_npast++;
+                        slot.n_past++;
                     }
 
                     SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
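
A counter-only sketch of the new while loop's batching behavior: tokens are appended until the prompt is exhausted or the batch holds n_batch entries, so a long prompt is processed across several decode calls (plain counters stand in for the llama batch; values illustrative):

```cpp
// Counter-only sketch: each pass of the server loop fills the batch with at
// most n_batch prompt tokens before the next llama_decode.
#include <cstdio>

int main() {
    const int n_prompt_tokens = 700;
    const int n_batch         = 256;

    int n_past = 0;
    while (n_past < n_prompt_tokens) {
        int batch_n_tokens = 0; // batch.n_tokens after the common_batch_add calls
        while (n_past < n_prompt_tokens && batch_n_tokens < n_batch) {
            batch_n_tokens++;
            n_past++;
        }
        // prints chunks of 256, 256, 188
        std::printf("decode %3d tokens, n_past = %d\n", batch_n_tokens, n_past);
    }
}
```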
@@ -2167,34 +2105,6 @@ struct server_context {
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
-            for (auto & slot : slots) {
-                if (slot.ga_n != 1) {
-                    // context extension via Self-Extend
-                    // TODO: simplify and/or abstract this
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
-                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                        SLT_DBG(slot, "div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
-                        slot.n_past_se -= bd;
-
-                        slot.ga_i += slot.ga_w / slot.ga_n;
-
-                        SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                    }
-
-                    slot.n_past_se += n_tokens;
-                }
-            }
-
             llama_batch batch_view = {
                 n_tokens,
                 batch.token + i,
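
For the record, the deleted block implemented the Self-Extend position remapping: each full group-attention window of width ga_w is compressed by the factor ga_n via llama_kv_cache_seq_add/div. A pure-arithmetic sketch of the resulting position mapping, mirroring the bd/ga_i updates in the deleted prompt loop:

```cpp
// Pure-arithmetic sketch of the Self-Extend position mapping the removed
// code maintained incrementally. Mirrors the bd = (ga_w/ga_n)*(ga_n - 1)
// rewind and the ga_i += ga_w/ga_n advance from the deleted loop.
#include <cstdio>

static int self_extend_pos(int p, int ga_n, int ga_w) {
    int ga_i = 0; // start of the current (compressed) window
    while (p >= ga_i + ga_w) {
        const int bd = (ga_w / ga_n) * (ga_n - 1);
        p    -= bd;          // pull the position back by the saved span
        ga_i += ga_w / ga_n; // each window occupies ga_w/ga_n positions
    }
    return p;
}

int main() {
    // with ga_n = 4, ga_w = 512: 512 -> 128, 1024 -> 256, 4096 -> 1024,
    // i.e. far positions are compressed by exactly the factor ga_n
    for (int p : {0, 511, 512, 1024, 4096}) {
        std::printf("%5d -> %d\n", p, self_extend_pos(p, 4, 512));
    }
}
```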