@@ -921,6 +921,8 @@ struct server_context {
921  921          slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
922  922 
923  923          slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
     924 +        slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
     925 +        slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
924  926 
925  927          if (slot.params.sampling.dry_base < 1.0f) {
926  928              slot.params.sampling.dry_base = defaults.sampling.dry_base;
@@ -2322,17 +2324,38 @@ struct server_context {
2322 2324                 continue;
2323 2325             }
2324 2326 
     2327 +           // determine the max draft that fits the current slot state
     2328 +           int n_draft_max = slot.params.speculative.n_max;
     2329 +
     2330 +           // note: n_past is not yet increased for the `id` token sampled above
     2331 +           // also, need to leave space for 1 extra token to allow context shifts
     2332 +           n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
     2333 +
     2334 +           if (slot.n_remaining > 0) {
     2335 +               n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
     2336 +           }
     2337 +
     2338 +           SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
     2339 +
     2340 +           if (n_draft_max < slot.params.speculative.n_min) {
     2341 +               SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
     2342 +
     2343 +               continue;
     2344 +           }
     2345 +
2325 2346             llama_token id = slot.sampled;
2326 2347 
2327 2348             struct common_speculative_params params_spec;
2328      -           params_spec.n_draft = std::min(slot.params.speculative.n_max, slot.n_ctx - slot.n_past - 1);
     2349 +           params_spec.n_draft = n_draft_max;
2329 2350             params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
2330 2351             params_spec.p_min   = slot.params.speculative.p_min;
2331 2352 
2332 2353             llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
2333 2354 
2334 2355             // ignore small drafts
2335 2356             if (slot.params.speculative.n_min > (int) draft.size()) {
     2357 +               SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
     2358 +
2336 2359                 continue;
2337 2360             }
2338 2361 
@@ -2344,6 +2367,8 @@ struct server_context {
2344 2367                 common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
2345 2368             }
2346 2369 
     2370 +           SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
     2371 +
2347 2372             llama_decode(ctx, slot.batch_spec);
2348 2373 
2349 2374             // the accepted tokens from the speculation
@@ -2372,7 +2397,7 @@ struct server_context {
2372 2397                 }
2373 2398             }
2374 2399 
2375      -           SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
     2400 +           SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
2376 2401         }
2377 2402     }
23782403
0 commit comments