@@ -1250,6 +1250,7 @@ struct server_context {
             chat_templates = common_chat_templates_init(model, "chatml");
         }
 
+        bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty();
         std::string & mmproj_path = params.mmproj.path;
         if (!mmproj_path.empty()) {
             mtmd_context_params mparams = mtmd_context_params_default();
@@ -1274,24 +1275,37 @@ struct server_context {
             // SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
             // }
 
-            if (!params.model_draft.empty()) {
+            if (has_draft_model) {
                 LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
                 return false;
             }
         }
         // Load draft model for speculative decoding if specified
-        if (!params.model_draft.empty()) {
-            LOG_INFO("loading draft model", {{"model", params.model_draft}});
+        if (has_draft_model) {
+            LLAMA_LOG_INFO("\n\n================================== loading DRAFT model==================================\n\n");
 
             gpt_params params_dft;
             params_dft.devices = params.devices_draft;
             params_dft.model = params.model_draft;
-            params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft;
             params_dft.n_gpu_layers = params.n_gpu_layers_draft;
-            params_dft.n_parallel = 1;
             params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
             params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
             params_dft.flash_attn = params.flash_attn;
+            if (!params.draft_params.empty()) {
+                auto [argc, argv] = parse_command_line("llama-server " + params.draft_params);
+                if (!gpt_params_parse(argc, argv, params_dft)) {
+                    gpt_params_print_usage(argc, argv, params_dft);
+                    free_command_line(argc, argv);
+                    return false;
+                }
+                free_command_line(argc, argv);
+            }
+            LOG_INFO("", {{"model", params_dft.model}});
+            if (params_dft.n_ctx == 0) {
+                params_dft.n_ctx = params.n_ctx_draft;
+            }
+            params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
+            params_dft.n_parallel = 1;
 
             llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
 
@@ -1361,8 +1375,8 @@ struct server_context {
             // Initialize speculative decoding if a draft model is loaded
             if (ctx_draft) {
                 slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
-
-                slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
+                // slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); // initialized twice
+                slot.ctx_dft = ctx_draft;
                 if (slot.ctx_dft == nullptr) {
                     LOG_ERROR("failed to create draft context", {});
                     return;
@@ -3010,7 +3024,7 @@ struct server_context {
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
                     }
-                    new_tokens.resize((int) prompt_tokens.size() - n_discard);
+                    new_tokens.resize(prompt_tokens.size() - n_discard);
                     prompt_tokens.clear();
                     prompt_tokens.insert(new_tokens);
                     slot.truncated = true;
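
Note: the hunk above calls `parse_command_line()` and `free_command_line()` to turn the `--draft-params` string into an `argc`/`argv` pair for `gpt_params_parse()`, but their definitions are outside this diff. The sketch below is only an illustration of what such helpers could look like, assuming a plain whitespace split with no quoting or escaping; it is not the implementation from this patch.

```cpp
// Hypothetical sketch (not from the patch): split a command string into an
// argc/argv pair on whitespace so it can be fed to gpt_params_parse(), and
// release the allocations afterwards.
#include <cstring>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

static std::pair<int, char **> parse_command_line(const std::string & cmd) {
    std::vector<std::string> args;
    std::istringstream iss(cmd);
    for (std::string tok; iss >> tok; ) {
        args.push_back(tok);
    }
    // copy each token into its own C string; argv[argc] stays nullptr
    char ** argv = new char *[args.size() + 1];
    for (size_t i = 0; i < args.size(); i++) {
        argv[i] = new char[args[i].size() + 1];
        std::strcpy(argv[i], args[i].c_str());
    }
    argv[args.size()] = nullptr;
    return { (int) args.size(), argv };
}

static void free_command_line(int argc, char ** argv) {
    for (int i = 0; i < argc; i++) {
        delete[] argv[i];
    }
    delete[] argv;
}
```

The `"llama-server "` prefix presumably only exists so that `argv[0]` looks like a program name, since the parser treats the first argument as the executable and starts reading options from `argv[1]`.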