@@ -921,6 +921,8 @@ struct server_context {
921921        slot.params .speculative .p_min  = json_value (data, " speculative.p_min"  , defaults.speculative .p_min );
922922
923923        slot.params .speculative .n_min  = std::min (slot.params .speculative .n_max , slot.params .speculative .n_min );
924+         slot.params .speculative .n_min  = std::max (slot.params .speculative .n_min , 2 );
925+         slot.params .speculative .n_max  = std::max (slot.params .speculative .n_max , 0 );
924926
925927        if  (slot.params .sampling .dry_base  < 1 .0f ) {
926928           slot.params .sampling .dry_base  = defaults.sampling .dry_base ;
@@ -2322,17 +2324,38 @@ struct server_context {
23222324                    continue ;
23232325                }
23242326
2327+                 //  determine the max draft that fits the current slot state
2328+                 int  n_draft_max = slot.params .speculative .n_max ;
2329+ 
2330+                 //  note: n_past is not yet increased for the `id` token sampled above
2331+                 //        also, need to leave space for 1 extra token to allow context shifts
2332+                 n_draft_max = std::min (n_draft_max, slot.n_ctx  - slot.n_past  - 2 );
2333+ 
2334+                 if  (slot.n_remaining  > 0 ) {
2335+                     n_draft_max = std::min (n_draft_max, slot.n_remaining  - 1 );
2336+                 }
2337+ 
2338+                 SLT_DBG (slot, " max possible draft: %d\n "  , n_draft_max);
2339+ 
2340+                 if  (n_draft_max < slot.params .speculative .n_min ) {
2341+                     SLT_DBG (slot, " the max possible draft is too small: %d < %d - skipping speculative decoding\n "  , n_draft_max, slot.params .speculative .n_min );
2342+ 
2343+                     continue ;
2344+                 }
2345+ 
23252346                llama_token id = slot.sampled ;
23262347
23272348                struct  common_speculative_params  params_spec;
2328-                 params_spec.n_draft    = slot. params . speculative . n_max ;
2349+                 params_spec.n_draft    = n_draft_max ;
23292350                params_spec.n_reuse    = llama_n_ctx (slot.ctx_dft ) - slot.params .speculative .n_max ;
23302351                params_spec.p_min      = slot.params .speculative .p_min ;
23312352
23322353                llama_tokens draft = common_speculative_gen_draft (slot.spec , params_spec, slot.cache_tokens , id);
23332354
23342355                //  ignore small drafts
23352356                if  (slot.params .speculative .n_min  > (int ) draft.size ()) {
2357+                     SLT_DBG (slot, " ignoring small draft: %d < %d\n "  , (int ) draft.size (), slot.params .speculative .n_min );
2358+ 
23362359                    continue ;
23372360                }
23382361
@@ -2344,6 +2367,8 @@ struct server_context {
23442367                    common_batch_add (slot.batch_spec , draft[i], slot.n_past  + 1  + i, { slot.id  }, true );
23452368                }
23462369
2370+                 SLT_DBG (slot, " decoding speculative batch, size = %d\n "  , slot.batch_spec .n_tokens );
2371+ 
23472372                llama_decode (ctx, slot.batch_spec );
23482373
23492374                //  the accepted tokens from the speculation
@@ -2372,7 +2397,7 @@ struct server_context {
23722397                    }
23732398                }
23742399
2375-                 SRV_DBG ( " accepted %d/%d draft tokens\n "  , (int ) ids.size () - 1 , (int ) draft.size ());
2400+                 SLT_DBG (slot,  " accepted %d/%d draft tokens, new n_past = %d \n "  , (int ) ids.size () - 1 , (int ) draft.size (), slot. n_past );
23762401            }
23772402        }
23782403
0 commit comments