@@ -3038,6 +3038,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3049,6 +3050,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min   = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past    += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok          = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob         = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3372,7 +3464,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
@@ -3510,108 +3605,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min   = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past    += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok          = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob         = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");
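
The removed block in the last hunk is the same speculative-decoding logic, relocated into the per-slot sampling loop shown in the second hunk (with the early-`continue` guards folded into nested `if`/`else` branches and the new `speculative_accepted` flag suppressing the "no tokens to decode" warning). As a reading aid, here is a minimal, self-contained sketch in plain C++ — not the llama.cpp API; `sample_and_accept_n` and `target_sample` are hypothetical stand-ins for `common_sampler_sample_and_accept_n` — of the accept-count convention the code relies on: the returned `ids` holds the draft tokens the target model agreed with plus one token freshly sampled by the target itself, so the last element is never a verified draft token, which is why `slot.n_draft_accepted` is credited with `ids.size() - 1`.

```cpp
// Minimal sketch of draft verification (assumed semantics, standard C++ only).
#include <cstdio>
#include <functional>
#include <vector>

using token = int;

// Hypothetical stand-in for common_sampler_sample_and_accept_n:
// walk the draft, keep every position where the target's own sample matches
// the drafted token, and always include the target's final sample.
static std::vector<token> sample_and_accept_n(
        const std::function<token(size_t)> & target_sample, // token the target picks at draft position i
        const std::vector<token> & draft) {
    std::vector<token> ids;
    for (size_t i = 0; i <= draft.size(); ++i) {
        const token t = target_sample(i);
        ids.push_back(t);                      // the target's sample is always kept
        if (i == draft.size() || t != draft[i]) {
            break;                             // stop at the first mismatch (or when the draft is exhausted)
        }
    }
    return ids; // verified draft prefix + one freshly sampled token
}

int main() {
    const std::vector<token> draft = {5, 7, 9, 11};
    // pretend the target agrees with the first two drafted tokens, then diverges
    auto target = [&](size_t i) -> token { return i < 2 ? draft[i] : 42; };

    const auto ids = sample_and_accept_n(target, draft);
    std::printf("accepted %zu/%zu draft tokens\n", ids.size() - 1, draft.size()); // prints 2/4
}
```

Under the same reading, `slot.n_past += ids.size()` advances past the verifying token `id` plus the accepted draft prefix, and `llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)` discards the KV-cache entries left over from the rejected tail of the speculative batch.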