@@ -3039,6 +3039,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3050,6 +3051,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3373,7 +3465,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
@@ -3514,108 +3609,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");
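
For reference, a small self-contained sketch of the bookkeeping the relocated block performs: the draft size is clamped so that the already-sampled token plus one extra token for context shifts still fit in the context, and the accepted-token counter excludes the one non-draft token that the sampler returns at the end. The numbers and variable names below are illustrative stand-ins for slot.n_ctx, slot.n_past and slot.params.speculative.*, not code from this PR.

    #include <algorithm>
    #include <cstdio>

    int main() {
        // hypothetical slot state (stand-ins for slot.n_ctx, slot.n_past, slot.params.speculative.*)
        const int n_ctx  = 4096;
        const int n_past = 1000;
        const int n_max  = 16; // speculative.n_max
        const int n_min  = 2;  // speculative.n_min

        // same clamp as above: leave room for the sampled token and 1 extra token for context shifts
        int n_draft_max = std::min(n_max, n_ctx - n_past - 2);

        if (n_draft_max < n_min) {
            std::printf("the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_min);
            return 0;
        }

        // suppose the draft model proposed 8 tokens and verification accepted 5 of them;
        // the ids returned by the sampler also include the token sampled after the last
        // accepted draft position, which is why the diff subtracts 1 from ids.size()
        const int n_drafted  = 8;
        const int ids_size   = 5 + 1;
        const int n_accepted = ids_size - 1;

        std::printf("accepted %d/%d draft tokens\n", n_accepted, n_drafted);
        return 0;
    }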