@@ -2,10 +2,11 @@
 
 #include "arg.h"
 #include "common.h"
-#include "log.h"
-#include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "log.h"
+#include "sampling.h"
+#include "speculative.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
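The new speculative.h header declares the common_speculative helpers this diff calls. Paraphrasing just the subset exercised below, with signatures inferred from the call sites rather than copied from the header (see common/speculative.h for the authoritative declarations):

    // Inferred from usage in this diff; treat as approximate.
    struct common_speculative;

    struct common_speculative_params {
        int   n_draft; // max tokens to draft per step
        int   n_reuse; // draft-context cache reuse
        float p_min;   // min draft confidence to keep drafting
    };

    common_speculative * common_speculative_init(llama_context * ctx_dft);
    void                 common_speculative_free(common_speculative * spec);
    bool                 common_speculative_are_compatible(llama_context * ctx_tgt, llama_context * ctx_dft);
    llama_tokens         common_speculative_gen_draft(common_speculative * spec,
                                                      common_speculative_params params,
                                                      const llama_tokens & prompt_tgt,
                                                      llama_token id_last);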
@@ -127,6 +128,12 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    llama_batch batch_spec;
+
+    llama_context * ctx_dft = nullptr;
+
+    common_speculative * spec = nullptr;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
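Each slot now carries its own speculative state: batch_spec holds the last verified token plus the drafted tokens for one verification pass, ctx_dft is a per-slot context on the shared draft model (so parallel slots keep independent draft KV caches), and spec is the drafting state. A minimal sketch of the resulting invariant (hypothetical helper, not part of the commit):

    // All three members are populated together in init() below, and only
    // when a draft model was configured at startup.
    static bool slot_can_speculate(const server_slot & slot) {
        return slot.ctx_dft != nullptr && slot.spec != nullptr;
    }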
@@ -591,11 +598,14 @@ struct server_response {
 };
 
 struct server_context {
+    common_params params;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<common_lora_adapter_container> loras;
 
-    common_params params;
+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
 
     llama_batch batch = {};
 
@@ -628,17 +638,33 @@ struct server_context {
             model = nullptr;
         }
 
+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
             if (slot.smpl != nullptr) {
+                llama_free(slot.ctx_dft);
+                slot.ctx_dft = nullptr;
+
+                common_speculative_free(slot.spec);
+                slot.spec = nullptr;
+
                 common_sampler_free(slot.smpl);
+                slot.smpl = nullptr;
+
+                llama_batch_free(slot.batch_spec);
             }
         }
 
         llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params_) {
+        SRV_INF("loading model '%s'\n", params_.model.c_str());
+
         params = params_;
 
         common_init_result llama_init = common_init_from_params(params);
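The destructor releases the speculative state in the same branch that frees the sampler: draft context first, then the speculator, then the sampler (now nulled after freeing), and finally the per-slot batch. A hypothetical consolidation of the slot-level teardown, relying on the same unconditional-call pattern the destructor itself uses (llama_free is a no-op on nullptr):

    // Sketch only; not part of this commit.
    static void slot_free_speculative(server_slot & slot) {
        llama_free(slot.ctx_dft);
        slot.ctx_dft = nullptr;

        common_speculative_free(slot.spec);
        slot.spec = nullptr;

        llama_batch_free(slot.batch_spec);
    }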
@@ -657,6 +683,40 @@ struct server_context {
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
 
+        if (!params.model_draft.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_.model_draft.c_str());
+
+            auto params_dft = params;
+
+            params_dft.model        = params.model_draft;
+            params_dft.n_gpu_layers = params.n_gpu_layers_draft;
+
+            if (params.draft_cpuparams.n_threads > 0) {
+                params_dft.cpuparams.n_threads = params.draft_cpuparams.n_threads;
+            }
+
+            params_dft.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
+
+            common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params.model_draft.c_str());
+                return false;
+            }
+
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.model_draft.c_str(), params.model.c_str());
+                return false;
+            }
+
+            cparams_dft = common_context_params_to_llama(params);
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
+
         return true;
     }
 
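load_model() now clones the server parameters, swaps in the draft model path, GPU offload count, and thread settings, loads the pair, and gates speculation on compatibility between the two contexts. The freshly created draft context is thrown away because every slot builds its own from cparams_dft. Condensed into a standalone sketch (assumes only the common helpers already shown in this diff; error reporting elided):

    static llama_model * try_load_draft(const common_params & base, llama_context * ctx_tgt) {
        common_params params_dft = base;

        params_dft.model        = base.model_draft;
        params_dft.n_gpu_layers = base.n_gpu_layers_draft;

        common_init_result init_dft = common_init_from_params(params_dft);
        if (init_dft.model == nullptr) {
            return nullptr;
        }

        if (!common_speculative_are_compatible(ctx_tgt, init_dft.context)) {
            llama_free(init_dft.context);
            llama_free_model(init_dft.model);
            return nullptr;
        }

        llama_free(init_dft.context); // each slot creates its own draft context
        return init_dft.model;
    }

Note that cparams_dft is derived from params, not params_dft, so the per-slot draft contexts inherit the target model's context parameters.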
@@ -685,6 +745,22 @@ struct server_context {
             slot.n_ctx = n_ctx_slot;
             slot.n_predict = params.n_predict;
 
+            if (model_dft) {
+                slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+                if (slot.ctx_dft == nullptr) {
+                    SRV_ERR("%s", "failed to create draft context\n");
+                    return;
+                }
+
+                slot.spec = common_speculative_init(slot.ctx_dft);
+                if (slot.spec == nullptr) {
+                    SRV_ERR("%s", "failed to create speculator\n");
+                    return;
+                }
+
+                slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1);
+            }
+
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
             slot.sparams = params.sampling;
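Per slot, init() creates a dedicated draft context and speculator and bails out if either cannot be constructed. The batch sizing is worth spelling out: every verification pass decodes the last verified token plus at most n_draft drafted tokens for a single sequence, hence

    // llama_batch_init(n_tokens, embd, n_seq_max) — the sizing, annotated:
    const int n_draft = 16; // example value; the server uses params.n_draft
    llama_batch batch_spec = llama_batch_init(
        /*n_tokens  =*/ n_draft + 1, // last verified token + up to n_draft drafted tokens
        /*embd      =*/ 0,           // the batch carries token ids, not embeddings
        /*n_seq_max =*/ 1);          // each slot decodes a single sequence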
@@ -2168,38 +2244,108 @@ struct server_context {
                 continue; // continue loop of slots
             }
 
-            completion_token_output result;
-            const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+            llama_token id;
 
-            common_sampler_accept(slot.smpl, id, true);
+            {
+                completion_token_output result;
 
-            slot.n_decoded += 1;
-            if (slot.n_decoded == 1) {
-                slot.t_start_generation = ggml_time_us();
-                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                metrics.on_prompt_eval(slot);
-            }
+                id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-            result.tok = id;
+                common_sampler_accept(slot.smpl, id, true);
 
-            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                slot.n_decoded += 1;
+                if (slot.n_decoded == 1) {
+                    slot.t_start_generation = ggml_time_us();
+                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                    metrics.on_prompt_eval(slot);
+                }
 
-            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
-                result.probs.push_back({
-                    cur_p->data[i].id,
-                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                });
-            }
+                result.tok = id;
 
-            if (!process_token(result, slot)) {
-                // release slot because of stop condition
-                slot.release();
-                slot.print_timings();
-                send_final_response(slot);
-                metrics.on_prediction(slot);
+                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+
+                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                    result.probs.push_back({
+                        cur_p->data[i].id,
+                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                    });
+                }
+
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
+                    slot.release();
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                }
             }
 
             slot.i_batch = -1;
+
+            if (slot.ctx_dft) {
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = params.n_draft;
+                params_spec.n_reuse = 256;
+                params_spec.p_min   = 0.9f;
+
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+                if (draft.size() > params.n_draft_min) {
+                    common_batch_clear(slot.batch_spec);
+                    common_batch_add  (slot.batch_spec, id, slot.n_past++, { slot.id }, true);
+
+                    for (size_t i = 0; i < draft.size(); ++i) {
+                        common_batch_add(slot.batch_spec, draft[i], slot.n_past + i, { slot.id }, true);
+                    }
+
+                    llama_decode(ctx, slot.batch_spec);
+
+                    const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
+
+                    slot.n_past += ids.size() - 1;
+
+                    slot.cache_tokens.push_back(id);
+
+                    for (size_t i = 0; i < ids.size(); ++i) {
+                        completion_token_output result;
+
+                        id = ids[i];
+
+                        common_sampler_accept(slot.smpl, id, true);
+
+                        slot.n_decoded += 1;
+                        if (slot.n_decoded == 1) {
+                            slot.t_start_generation = ggml_time_us();
+                            slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                            metrics.on_prompt_eval(slot);
+                        }
+
+                        result.tok = id;
+
+                        const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+
+                        for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                            result.probs.push_back({
+                                cur_p->data[i].id,
+                                i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                            });
+                        }
+
+                        if (!process_token(result, slot)) {
+                            // release slot because of stop condition
+                            slot.release();
+                            slot.print_timings();
+                            send_final_response(slot);
+                            metrics.on_prediction(slot);
+                            break;
+                        }
+                    }
+
+                    llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+                    slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+                }
+            }
         }
     }
 
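The generation loop now has two phases. The first block is the pre-existing sample/accept/process path, wrapped in a scope so the sampled id survives for the speculative phase. The second block runs only for slots with a draft context: the draft model proposes up to n_draft tokens (n_reuse and p_min are hard-coded draft-side knobs here, roughly "how much draft cache to reuse" and "minimum draft confidence to keep drafting"); if the proposal is longer than n_draft_min, the verified token plus the whole draft are decoded on the target in one batch and the target sampler re-samples every drafted position. The bookkeeping afterwards assumes common_sampler_sample_n() returns the accepted draft prefix plus exactly one corrected token, which a std-only sketch makes explicit (hypothetical helper, not in the commit):

    #include <cstdint>
    #include <vector>

    using llama_token  = int32_t;
    using llama_tokens = std::vector<llama_token>;

    // ids[i] is what the target sampled at the position of draft[i]; the
    // accepted run ends at the first disagreement, and ids[n] is the
    // corrected token that replaces draft[n].
    static size_t count_accepted(const llama_tokens & draft, const llama_tokens & ids) {
        size_t n = 0;
        while (n + 1 < ids.size() && n < draft.size() && ids[n] == draft[n]) {
            n++;
        }
        return n;
    }

Under that reading, slot.n_past advances by one for the verified token plus ids.size() - 1 for the accepted drafts, llama_kv_cache_seq_rm() truncates the rejected tail of the speculative batch from the target KV cache, and cache_tokens receives the verified token plus all but the last sampled id; that final id is carried into the next iteration as the current token.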