@@ -3038,6 +3038,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;
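+        // set when a slot makes progress through the speculative path below; used to
+        // suppress the "no tokens to decode" warning when the main batch stays empty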
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3049,6 +3050,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min   = slot.params.speculative.p_min;
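+                    // n_draft caps how many tokens the draft model may propose, n_reuse bounds how much
+                    // draft-model context common_speculative_gen_draft may reuse from the previous call,
+                    // and p_min is the draft probability below which drafting stops early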
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
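+                        // `id` (sampled on the previous step, not yet decoded) goes in at position n_past,
+                        // followed by the draft, so the decode below yields logits at every position the
+                        // verification step needs to sample from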
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
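+                        // ids holds the accepted draft tokens plus the one token sampled from the target
+                        // model where the draft diverged (or after the full draft), so ids.size() >= 1
+                        // and the number of accepted draft tokens is ids.size() - 1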
+
+                        slot.n_past    += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
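+                        // the cache now contains `id` plus all accepted tokens except the final sampled one
+                        // (not decoded yet), and seq_rm drops the KV entries of the rejected draft tokens
+                        // beyond the updated n_past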
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok          = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob         = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3372,7 +3464,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
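+            // slots that advanced via the speculative path above add nothing to the main batch,
+            // so an empty batch is expected in that case and not worth warning about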
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
@@ -3513,108 +3608,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min   = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past    += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok          = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob         = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");