@@ -197,6 +197,7 @@ struct server_task {
     // used by SERVER_TASK_TYPE_INFERENCE
     slot_params  params;
     llama_tokens prompt_tokens;
+    llama_tokens prediction_tokens;
     int id_selected_slot = -1;
 
     // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
@@ -604,6 +605,7 @@ struct server_task_result_cmpl_final : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
     int32_t n_tokens_cached;
+    int32_t n_lookup_used;
     bool has_new_line;
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;
@@ -660,6 +662,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"stopping_word",      stopping_word},
             {"tokens_cached",      n_tokens_cached},
             {"timings",            timings.to_json()},
+            {"prediction_tokens_accepted", n_lookup_used},
         };
         if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
@@ -695,7 +698,10 @@ struct server_task_result_cmpl_final : server_task_result {
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens",     n_prompt_tokens},
-                {"total_tokens",      n_decoded + n_prompt_tokens}
+                {"total_tokens",      n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used},
+                }}
             }},
             {"id", oaicompat_cmpl_id}
         };
@@ -771,11 +777,14 @@ struct server_task_result_cmpl_final : server_task_result {
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens",     n_prompt_tokens},
-                {"total_tokens",      n_decoded + n_prompt_tokens}
+                {"total_tokens",      n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used},
+                }}
             }},
             {"id", oaicompat_cmpl_id}
         };
-
+
         // extra fields for debugging purposes
         if (verbose) {
             res["__verbose"] = to_json_non_oaicompat();
@@ -811,6 +820,9 @@ struct server_task_result_cmpl_final : server_task_result {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens",     n_prompt_tokens},
                 {"total_tokens",      n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used},
+                }}
             }},
         };
 
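For reference, the three usage blocks above all gain the same nested object. The snippet below is illustrative only (the token counts are invented, not taken from any run); it just shows the shape of the OpenAI-compatible usage payload after this change, built with nlohmann::json the same way the server does:

// Illustrative: expected shape of the "usage" object with accepted prediction tokens.
// The counts are made up; "accepted_prediction_tokens" mirrors slot.n_lookup_used.
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    json usage = {
        {"completion_tokens", 128},
        {"prompt_tokens",     512},
        {"total_tokens",      640},
        {"completion_tokens_details", {
            {"accepted_prediction_tokens", 96}
        }}
    };
    std::cout << usage.dump(2) << std::endl;
    return 0;
}
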
@@ -1235,16 +1247,22 @@ struct server_slot {
     int32_t n_ctx       = 0;  // context size per slot
     int32_t n_past      = 0;
     int32_t n_decoded   = 0;
+    int32_t n_lookup_used = 0;
     int32_t n_remaining = -1;
     int32_t i_batch     = -1;
     int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict
 
+    // for "predicted outputs"
+    int32_t lookup_n_adaptive = 1;
+    int32_t lookup_index      = 0;
+
     // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
 
     // input prompt tokens
     llama_tokens prompt_tokens;
+    llama_tokens prediction_tokens;
 
     size_t last_nl_pos = 0;
 
@@ -1912,9 +1930,8 @@ struct server_context {
             slot.n_ctx     = n_ctx_slot;
             slot.n_predict = params_base.n_predict;
 
+            slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
             if (model_dft) {
-                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
-
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
                     SRV_ERR("%s", "failed to create draft context\n");
@@ -2034,6 +2051,7 @@ struct server_context {
         slot.task_type      = task.type;
         slot.params         = std::move(task.params);
         slot.prompt_tokens  = std::move(task.prompt_tokens);
+        slot.prediction_tokens = std::move(task.prediction_tokens);
 
         if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
@@ -2345,6 +2363,7 @@ struct server_context {
         res->n_decoded            = slot.n_decoded;
         res->n_prompt_tokens      = slot.n_prompt_tokens;
         res->n_tokens_cached      = slot.n_past;
+        res->n_lookup_used        = slot.n_lookup_used;
         res->has_new_line         = slot.has_new_line;
         res->stopping_word        = slot.stopping_word;
         res->stop                 = slot.stop;
@@ -3217,6 +3236,137 @@ struct server_context {
                 }
             }
 
+            // apply "predicted outputs", i.e. user-specified speculation,
+            // using a simple lookup decoding method
+            for (auto & slot : slots) {
+                // don't use lookup if we are also using a draft model
+                if (slot.can_speculate() || !slot.is_processing() || slot.prediction_tokens.size() < 2) {
+                    continue;
+                }
+                if (slot.state != SLOT_STATE_GENERATING) {
+                    continue;
+                }
+
+                // adaptive speculation window:
+                // grow the window every time all drafted tokens were accepted,
+                // otherwise reset it
+                auto draft_start_pos = 1;
+                bool found = false;
+                // first look for a match at the expected position
+                SLT_DBG(slot, "Looking up prediction tokens at index %d/%d\n", (int) slot.lookup_index, (int) slot.prediction_tokens.size());
+                if (slot.lookup_index > 0 &&
+                    slot.lookup_index < static_cast<int32_t>(slot.prediction_tokens.size()) &&
+                    slot.prediction_tokens[slot.lookup_index - 1] == slot.sampled) {
+                    found = true;
+                    draft_start_pos = slot.lookup_index;
+                    // TODO: what is a good scaling law here?
+                    // growing the window too fast will likely fail,
+                    // but windows that are too small in the beginning also hurt perf
+                    slot.lookup_n_adaptive = std::max(16, slot.lookup_n_adaptive*2);
+                } else {
+                    // otherwise find the first match in prediction_tokens
+                    slot.lookup_n_adaptive = 1; // default
+                    for (; draft_start_pos < static_cast<int32_t>(slot.prediction_tokens.size()); draft_start_pos++) {
+                        if (slot.prediction_tokens[draft_start_pos - 1] == slot.sampled) {
+                            found = true;
+                            break;
+                        }
+                    }
+                }
+                if (!found) continue;
+
+                // we erase the accepted tokens later, so we're looking for the same position next time
+                // increment by one because the next token will be generated
+                slot.lookup_index = draft_start_pos + 1;
+
+                llama_tokens draft = std::vector(
+                    slot.prediction_tokens.begin() + draft_start_pos,
+                    slot.prediction_tokens.end()
+                );
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.lookup_n_adaptive;
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                n_draft_max = std::min(n_draft_max, static_cast<int>(draft.size()));
+                // NOTE: we use speculative.n_max here as the upper limit, but
+                // in general we want to allow large drafts, as opposed to when
+                // using a draft model. But this is linked to `slot.batch_spec`
+                // size also.
+                n_draft_max = std::min(n_draft_max, slot.params.speculative.n_max);
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                draft.resize(n_draft_max);
+
+                llama_token id = slot.sampled;
+
+                // construct the speculation batch
+                common_batch_clear(slot.batch_spec);
+                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                for (size_t i = 0; i < draft.size(); ++i) {
+                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                }
+
+                llama_decode(ctx, slot.batch_spec);
+
+                // the accepted tokens from the speculation
+                // TODO: can we stream these? Would be nice to reduce jankiness in UIs
+                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                const auto n_accepted = ids.size() - 1;
+                slot.n_lookup_used += n_accepted;
+
+                if (n_accepted > 0) {
+                    // remove the prediction tokens that were used + the next token
+                    // (because it will be generated)
+                    slot.prediction_tokens.erase(
+                        slot.prediction_tokens.begin() + draft_start_pos,
+                        std::min(
+                            slot.prediction_tokens.end(),
+                            slot.prediction_tokens.begin() + draft_start_pos + n_accepted + 1
+                        )
+                    );
+                    if (n_accepted < draft.size()) {
+                        // reset speculation as we didn't use the full draft
+                        slot.lookup_n_adaptive = 1;
+                    }
+                }
+
+                for (size_t i = 0; i < ids.size(); ++i) {
+                    // NOTE: we need to update these here to avoid stopping early
+                    slot.n_past++;
+                    slot.n_decoded++;
+                    completion_token_output result;
+
+                    result.tok          = ids[i];
+                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                    result.prob         = 1.0f; // set later
+
+                    // TODO: set result.probs
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        break;
+                    }
+                }
+
+                slot.cache_tokens.push_back(id);
+                slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+                llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+                SLT_DBG(slot, "accepted %d/%d prediction tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+            }
+
             // do speculative decoding
             for (auto & slot : slots) {
                 if (!slot.is_processing() || !slot.can_speculate()) {
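
The block above is the core of the feature. As a standalone illustration (not part of the patch), the matching and adaptive-window logic reduces to roughly the following; a plain vector and a small struct stand in for the slot fields prediction_tokens, sampled, lookup_index and lookup_n_adaptive, and lookup_draft / n_draft_cap are names invented for this sketch:

// Sketch of the lookup step: given the last sampled token, find where it occurs in the
// user-supplied prediction and decide how many of the following tokens to draft.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

struct lookup_state {
    int32_t lookup_index      = 0; // where we expect the next match, based on the previous step
    int32_t lookup_n_adaptive = 1; // adaptive draft window size
};

// Returns the tokens to draft next, or an empty vector if the sampled token is not found.
static std::vector<llama_token> lookup_draft(
        const std::vector<llama_token> & prediction,
        llama_token sampled,
        lookup_state & st,
        int32_t n_draft_cap) {
    int32_t pos   = 1;
    bool    found = false;

    // fast path: the sampled token is exactly where we expected it -> grow the window
    if (st.lookup_index > 0 &&
        st.lookup_index < (int32_t) prediction.size() &&
        prediction[st.lookup_index - 1] == sampled) {
        found = true;
        pos   = st.lookup_index;
        st.lookup_n_adaptive = std::max(16, st.lookup_n_adaptive * 2);
    } else {
        // slow path: scan for the first occurrence and restart with a small window
        st.lookup_n_adaptive = 1;
        for (; pos < (int32_t) prediction.size(); pos++) {
            if (prediction[pos - 1] == sampled) { found = true; break; }
        }
    }
    if (!found) {
        return {};
    }

    st.lookup_index = pos + 1; // expected match position for the next call

    const int32_t n_draft = std::min({st.lookup_n_adaptive, n_draft_cap, (int32_t) (prediction.size() - pos)});
    return std::vector<llama_token>(prediction.begin() + pos, prediction.begin() + pos + n_draft);
}

int main() {
    // hypothetical token ids of the user-supplied prediction
    const std::vector<llama_token> prediction = {10, 11, 12, 13, 14, 15, 16, 17};
    lookup_state st;

    // the model just sampled token 12 -> draft the tokens that follow it in the prediction
    const auto draft = lookup_draft(prediction, /*sampled =*/ 12, st, /*n_draft_cap =*/ 4);
    std::printf("drafting %zu token(s), first = %d\n", draft.size(), draft.empty() ? -1 : draft[0]);
    return 0;
}

In the server the returned draft is then verified with a single llama_decode() call on slot.batch_spec, and common_sampler_sample_and_accept_n() reports how many of the drafted tokens the sampler actually accepted.
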
@@ -3838,10 +3988,17 @@ int main(int argc, char ** argv) {
 
         try {
             const auto & prompt = data.at("prompt");
+            const auto & prediction_obj = json_value(data, "prediction", json());
+            const auto & prediction = json_value(prediction_obj, "content", std::string());
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
             // SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
             std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+            std::vector<llama_tokens> tokenized_prediction;
+            if (!prediction.empty()) {
+                tokenized_prediction = tokenize_input_prompts(ctx_server.vocab, prediction, true, true);
+            }
+
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
@@ -3850,6 +4007,10 @@ int main(int argc, char ** argv) {
                 task.index = i;
 
                 task.prompt_tokens    = std::move(tokenized_prompts[i]);
+
+                if (!tokenized_prediction.empty()) {
+                    task.prediction_tokens = std::vector(tokenized_prediction[0].begin(), tokenized_prediction[0].end());
+                }
                 task.params           = server_task::params_from_json_cmpl(
                                             ctx_server.ctx,
                                             ctx_server.params_base,
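
Finally, on the request side the handler above reads an optional "prediction" object with a "content" string next to "prompt", mirroring the OpenAI "predicted outputs" field. An assumed example of a request body for, e.g., the /completion endpoint (the prompt and prediction text are invented) could be built like this:

// Assumed example request body; the handler reads data["prediction"]["content"] as a string,
// tokenizes it, and attaches the tokens to the task as prediction_tokens.
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    json body = {
        {"prompt",    "Rewrite the following function so that it returns early on error:\n..."},
        {"n_predict", 512},
        // the client's guess at the output, e.g. the mostly-unchanged original code;
        // spans that match are verified in bulk instead of being generated token by token
        {"prediction", {
            {"content", "int process(...) {\n    ...\n}\n"}
        }}
    };
    std::cout << body.dump(2) << std::endl;
    return 0;
}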