@@ -3045,6 +3045,7 @@ struct server_context {

         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;

         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3056,6 +3057,97 @@ struct server_context {
                 continue;
             }

+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
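For reference, the clamp on `n_draft_max` in the hunk above is plain integer arithmetic: the draft is limited by the configured `speculative.n_max`, by the context space left after the sampled token plus one spare position for context shifts, and, when a token budget is set, by the remaining budget. The standalone sketch below (hypothetical helper `max_draft_size`, not part of the patch) reproduces that computation under those assumptions:

```cpp
#include <algorithm>
#include <cstdio>

// Hypothetical helper mirroring the clamping logic from the hunk above:
// the draft may not exceed the configured maximum, must leave room for the
// sampled token plus one spare slot for context shifts, and must stay below
// the remaining generation budget (when a budget is set).
static int max_draft_size(int n_max, int n_ctx, int n_past, int n_remaining) {
    int n_draft_max = n_max;
    n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2);
    if (n_remaining > 0) {
        n_draft_max = std::min(n_draft_max, n_remaining - 1);
    }
    return n_draft_max;
}

int main() {
    // example: n_max = 16, 4096-token context, 4090 tokens used, 8 tokens
    // left in the request budget -> the draft is clamped to 4 tokens
    printf("%d\n", max_draft_size(16, 4096, 4090, 8));
    return 0;
}
```

If the result drops below `speculative.n_min`, the new code skips speculation for this slot and falls through to normal batching, matching the old early `continue`.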
@@ -3379,7 +3471,10 @@ struct server_context {
         }

         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }

@@ -3520,108 +3615,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }

         SRV_DBG("%s", "run slots completed\n");
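Taken together, the hunks move speculative decoding from a separate pass after batch processing into the main per-slot loop, and the new `speculative_accepted` flag keeps the server from warning about an empty batch when a slot has already produced its tokens speculatively. A rough control-flow sketch of the reordered update (simplified stand-in types, not the real `server_slot`/`server_context`):

```cpp
#include <cstdio>
#include <vector>

// Simplified stand-in for the slot state involved (hypothetical type).
struct Slot {
    bool generating    = false;
    bool can_speculate = false;
};

// Sketch of the reordered update loop: speculation now happens while the
// shared batch is being assembled, and a successful speculative step records
// that tokens were produced so the empty-batch warning can be skipped.
static void update_slots_sketch(std::vector<Slot> & slots) {
    int  batch_n_tokens       = 0;
    bool speculative_accepted = false;

    for (auto & slot : slots) {
        if (slot.generating && slot.can_speculate) {
            // draft, decode the speculation batch, accept tokens ...
            speculative_accepted = true;
            continue; // this slot is finished for this iteration
        }
        if (slot.generating) {
            batch_n_tokens += 1; // the slot's sampled token joins the shared batch
        }
    }

    if (batch_n_tokens == 0) {
        if (!speculative_accepted) {
            fprintf(stderr, "no tokens to decode\n"); // only a genuine stall warns
        }
        return;
    }

    // ... decode the shared batch as before ...
}

int main() {
    std::vector<Slot> slots = { { true, true }, { false, false } };
    update_slots_sketch(slots); // no warning: the first slot went speculative
    return 0;
}
```

One more note on the accounting that appears in both the added and the removed block: the `ids` returned by `common_sampler_sample_and_accept_n` hold the accepted draft tokens plus the one token the target model samples after them, which is why `n_draft_accepted` grows by `ids.size() - 1` while `n_draft_total` counts the full draft.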