@@ -489,8 +489,12 @@ struct result_timings {
489489    double  predicted_per_token_ms;
490490    double  predicted_per_second;
491491
492+     //  Optional speculative metrics - only included when > 0
493+     int32_t  draft_n = 0 ;
494+     int32_t  draft_n_accepted = 0 ;
495+ 
492496    json to_json () const  {
493-         return  {
497+         json base =  {
494498            {" prompt_n" 
495499            {" prompt_ms" 
496500            {" prompt_per_token_ms" 
@@ -501,6 +505,13 @@ struct result_timings {
501505            {" predicted_per_token_ms" 
502506            {" predicted_per_second" 
503507        };
508+ 
509+         if  (draft_n > 0 ) {
510+             base[" draft_n" 
511+             base[" draft_n_accepted" 
512+         }
513+ 
514+         return  base;
504515    }
505516};
506517
@@ -1299,6 +1310,10 @@ struct server_slot {
12991310
13001311    std::function<void (int )> callback_on_release;
13011312
1313+     //  Speculative decoding stats
1314+     int32_t  n_draft_total = 0 ;      //  Total draft tokens generated
1315+     int32_t  n_draft_accepted = 0 ;   //  Draft tokens actually accepted
1316+ 
13021317    void  reset () {
13031318        SLT_DBG (*this , " %s" " \n " 
13041319
@@ -1315,6 +1330,10 @@ struct server_slot {
13151330
13161331        generated_tokens.clear ();
13171332        generated_token_probs.clear ();
1333+ 
1334+         //  clear speculative decoding stats
1335+         n_draft_total = 0 ;
1336+         n_draft_accepted = 0 ;
13181337    }
13191338
13201339    bool  is_non_causal () const  {
@@ -1381,6 +1400,12 @@ struct server_slot {
13811400        timings.predicted_per_token_ms  = t_token_generation / n_decoded;
13821401        timings.predicted_per_second  = 1e3  / t_token_generation * n_decoded;
13831402
1403+         //  Add speculative metrics
1404+         if  (n_draft_total > 0 ) {
1405+             timings.draft_n  = n_draft_total;
1406+             timings.draft_n_accepted  = n_draft_accepted;
1407+         }
1408+ 
13841409        return  timings;
13851410    }
13861411
@@ -1428,6 +1453,15 @@ struct server_slot {
14281453                t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
14291454                t_token_generation, n_decoded, t_gen, n_gen_second,
14301455                t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
1456+ 
1457+         if  (n_draft_total > 0 ) {
1458+             const  float  draft_ratio = (float ) n_draft_accepted / n_draft_total;
1459+             SLT_INF (*this ,
1460+                     " \n " 
1461+                     " draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n " 
1462+                     draft_ratio, n_draft_accepted, n_draft_total
1463+             );
1464+         }
14311465    }
14321466
14331467    json to_json () const  {
@@ -3290,6 +3324,9 @@ struct server_context {
32903324
32913325                llama_tokens draft = common_speculative_gen_draft (slot.spec , params_spec, slot.cache_tokens , id);
32923326
3327+                 //  keep track of total number of tokens generated in the draft
3328+                 slot.n_draft_total  += draft.size ();
3329+ 
32933330                //  ignore small drafts
32943331                if  (slot.params .speculative .n_min  > (int ) draft.size ()) {
32953332                    SLT_DBG (slot, " ignoring small draft: %d < %d\n " int ) draft.size (), slot.params .speculative .n_min );
@@ -3315,6 +3352,9 @@ struct server_context {
33153352                slot.n_past     += ids.size ();
33163353                slot.n_decoded  += ids.size ();
33173354
3355+                 //  update how many tokens out of draft was accepted
3356+                 slot.n_draft_accepted  += ids.size () - 1 ;
3357+ 
33183358                slot.cache_tokens .push_back (id);
33193359                slot.cache_tokens .insert (slot.cache_tokens .end (), ids.begin (), ids.end () - 1 );
33203360
0 commit comments