@@ -3040,6 +3040,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3051,6 +3052,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
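Aside (not part of the patch): the clamp on the draft size above boils down to two std::min bounds. The sketch below uses hypothetical slot values to show how n_draft_max ends up limited by the remaining context (minus 2: the just-sampled token plus one spare token for context shifts) and by how many tokens the request may still generate.

    // illustrative sketch only; the concrete values are made up, not taken from the patch
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_max       = 16;   // slot.params.speculative.n_max
        const int n_ctx       = 4096; // slot.n_ctx
        const int n_past      = 4080; // tokens already in the slot's context
        const int n_remaining = 8;    // tokens the request is still allowed to generate

        int n_draft_max = n_max;

        // leave room for the sampled token and 1 extra token for context shifts
        n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2);

        if (n_remaining > 0) {
            n_draft_max = std::min(n_draft_max, n_remaining - 1);
        }

        printf("n_draft_max = %d\n", n_draft_max); // prints 7 for these values
        return 0;
    }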
@@ -3374,7 +3466,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
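Aside (not part of the patch): the new speculative_accepted flag is why the "no tokens to decode" warning is now conditional. A slot whose tokens were all produced through its own speculative batch contributes nothing to the shared batch, so an empty batch is expected rather than suspicious. A toy, self-contained sketch of that control flow, using made-up per-slot flags:

    // toy sketch; "slots" are reduced to flags meaning "served by the speculative path"
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<bool> used_speculation = { true, true }; // hypothetical slots

        int  batch_n_tokens       = 0;
        bool speculative_accepted = false;

        for (const bool spec : used_speculation) {
            if (spec) {
                speculative_accepted = true; // tokens already decoded via the slot's own batch
                continue;                    // nothing is added to the shared batch
            }
            batch_n_tokens++; // a non-speculative slot contributes its sampled token
        }

        // mirrors the gated warning: an empty batch is only worth flagging
        // when no slot went through the speculative path
        if (batch_n_tokens == 0 && !speculative_accepted) {
            printf("no tokens to decode\n"); // not reached in this scenario
        }
        return 0;
    }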
@@ -3515,108 +3610,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");