
Commit 9dfbc69

Server: Add --draft-params to set draft model parameter via command line args (#932)
* Add command line argument for draft model
* Remove second context of draft model
* Format print
* Print usage if parsing -draft fails

Co-authored-by: firecoperana <firecoperana>
1 parent ad688e1 commit 9dfbc69
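
For context, the new -draft / --draft-params option takes a single quoted string of ordinary llama options that is parsed separately for the draft model. An invocation would look roughly like the following (model names and flag values are illustrative, not taken from the commit):

llama-server -m main-model.gguf -c 8192 --draft-params "-m draft-model.gguf -ngl 99 -c 4096"

Options inside --draft-params are fed through the regular gpt_params parser into the draft model's own gpt_params, so they override the legacy per-draft settings (model_draft, n_gpu_layers_draft, n_ctx_draft, the draft cache types) that are copied in first.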

File tree: 3 files changed (+78 -9 lines changed)

common/common.cpp

Lines changed: 53 additions & 1 deletion
@@ -282,6 +282,53 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string&
 // CLI argument parsing
 //
 
+std::pair<int, char**> parse_command_line(const std::string& commandLine) {
+    std::vector<std::string> tokens;
+    std::string current;
+    bool inQuotes = false;
+
+    for (size_t i = 0; i < commandLine.length(); i++) {
+        char c = commandLine[i];
+
+        if (c == '\"') {
+            inQuotes = !inQuotes;
+        }
+        else if (c == ' ' && !inQuotes) {
+            if (!current.empty()) {
+                tokens.push_back(current);
+                current.clear();
+            }
+        }
+        else {
+            current += c;
+        }
+    }
+
+    if (!current.empty()) {
+        tokens.push_back(current);
+    }
+
+    int argc = static_cast<int>(tokens.size());
+    char** argv = new char* [static_cast<size_t>(argc) + 1];
+
+    for (int i = 0; i < argc; i++) {
+        argv[i] = new char[tokens[i].length() + 1];
+        std::strcpy(argv[i], tokens[i].c_str());
+    }
+    argv[argc] = nullptr;
+    return { argc, argv };
+}
+
+void free_command_line(int argc, char** argv) {
+    if (argv == nullptr) return;
+
+    for (int i = 0; i < argc; i++) {
+        delete[] argv[i];
+    }
+    delete[] argv;
+}
+
+
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model

@@ -1254,6 +1301,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cuda_params = argv[i];
         return true;
     }
+    if (arg == "-draft" || arg == "--draft-params") {
+        CHECK_ARG
+        params.draft_params = argv[i];
+        return true;
+    }
     if (arg == "--cpu-moe" || arg == "-cmoe") {
         params.tensor_buft_overrides.push_back({strdup("\\.ffn_(up|down|gate)_exps\\.weight"), ggml_backend_cpu_buffer_type()});
         return true;

@@ -2081,7 +2133,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "backend" });
     options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
     options.push_back({ "*", "-cuda, --cuda-params", "comma separate list of cuda parameters" });
-
+    options.push_back({ "*", "-draft, --draft-params", "comma separate list of draft model parameters" });
     if (llama_supports_mlock()) {
         options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
     }
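
Since parse_command_line hands ownership of a heap-allocated argv back to the caller, a minimal usage sketch may help (not part of the commit; it assumes compilation inside the repo so that common.h is available):

// Sketch only: demonstrates the new helpers from common/common.cpp.
#include "common.h"

#include <cstdio>

int main() {
    // Quoted segments stay together as one token; the quote characters are stripped.
    auto [argc, argv] = parse_command_line("llama-server -m \"my model.gguf\" -ngl 99");

    for (int i = 0; i < argc; i++) {
        printf("argv[%d] = %s\n", i, argv[i]); // llama-server, -m, my model.gguf, -ngl, 99
    }

    // parse_command_line allocates argv with new[]; the caller must release it here.
    free_command_line(argc, argv);
    return 0;
}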

common/common.h

Lines changed: 3 additions & 0 deletions
@@ -130,6 +130,7 @@ struct model_paths {
 struct gpt_params {
     std::string devices;
     std::string devices_draft;
+    std::string draft_params;
 
     uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
 

@@ -375,6 +376,8 @@ struct gpt_params {
 };
 
 
+std::pair<int, char**> parse_command_line(const std::string& commandLine);
+void free_command_line(int argc, char** argv);
 
 void gpt_params_handle_hf_token(gpt_params & params);
 void gpt_params_parse_from_env(gpt_params & params);

examples/server/server.cpp

Lines changed: 22 additions & 8 deletions
@@ -1250,6 +1250,7 @@ struct server_context {
         chat_templates = common_chat_templates_init(model, "chatml");
     }
 
+    bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty();
     std::string & mmproj_path = params.mmproj.path;
     if (!mmproj_path.empty()) {
         mtmd_context_params mparams = mtmd_context_params_default();

@@ -1274,24 +1275,37 @@
         // SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
         //}
 
-        if (!params.model_draft.empty()) {
+        if (has_draft_model) {
             LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
             return false;
         }
     }
     // Load draft model for speculative decoding if specified
-    if (!params.model_draft.empty()) {
-        LOG_INFO("loading draft model", {{"model", params.model_draft}});
+    if (has_draft_model) {
+        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
 
         gpt_params params_dft;
         params_dft.devices = params.devices_draft;
         params_dft.model = params.model_draft;
-        params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft;
         params_dft.n_gpu_layers = params.n_gpu_layers_draft;
-        params_dft.n_parallel = 1;
         params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
         params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
         params_dft.flash_attn = params.flash_attn;
+        if (!params.draft_params.empty()) {
+            auto [argc, argv] = parse_command_line("llama-server " + params.draft_params);
+            if (!gpt_params_parse(argc, argv, params_dft)) {
+                gpt_params_print_usage(argc, argv, params_dft);
+                free_command_line(argc, argv);
+                return false;
+            };
+            free_command_line(argc, argv);
+        }
+        LOG_INFO("", { {"model", params_dft.model} });
+        if (params_dft.n_ctx == 0) {
+            params_dft.n_ctx = params.n_ctx_draft;
+        }
+        params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
+        params_dft.n_parallel = 1;
 
         llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
 

@@ -1361,8 +1375,8 @@ struct server_context {
         // Initialize speculative decoding if a draft model is loaded
         if (ctx_draft) {
             slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
-
-            slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
+            // slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); // initialized twice
+            slot.ctx_dft = ctx_draft;
             if (slot.ctx_dft == nullptr) {
                 LOG_ERROR("failed to create draft context", {});
                 return;

@@ -3010,7 +3024,7 @@ struct server_context {
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
                     }
-                    new_tokens.resize((int) prompt_tokens.size() - n_discard);
+                    new_tokens.resize(prompt_tokens.size() - n_discard);
                     prompt_tokens.clear();
                     prompt_tokens.insert(new_tokens);
                     slot.truncated = true;
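
The reshuffled n_ctx logic above gives the draft context size a three-level fallback: a non-zero -c value supplied inside --draft-params wins, otherwise the dedicated draft context setting (n_ctx_draft), otherwise the main context divided across parallel slots. A standalone sketch of that resolution order (hypothetical helper, not code from the commit):

#include <cstdint>
#include <cstdio>

// Mirrors the fallback order used when setting params_dft.n_ctx in server.cpp.
static uint32_t resolve_draft_n_ctx(uint32_t n_ctx_from_draft_params, // -c inside --draft-params (0 = unset)
                                    uint32_t n_ctx_draft,             // dedicated draft context setting (0 = unset)
                                    uint32_t n_ctx_main,
                                    uint32_t n_parallel) {
    uint32_t n_ctx = n_ctx_from_draft_params;
    if (n_ctx == 0) {
        n_ctx = n_ctx_draft;
    }
    return n_ctx == 0 ? n_ctx_main / n_parallel : n_ctx;
}

int main() {
    printf("%u\n", resolve_draft_n_ctx(4096, 2048, 8192, 2)); // 4096: --draft-params -c wins
    printf("%u\n", resolve_draft_n_ctx(0,    2048, 8192, 2)); // 2048: fall back to the draft context setting
    printf("%u\n", resolve_draft_n_ctx(0,    0,    8192, 2)); // 4096: split the main context across slots
    return 0;
}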
