Skip to content

Commit 6ef22f0

Browse files
committed
common : add -hfd option for the draft model
1 parent aea8ddd commit 6ef22f0

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

common/arg.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
299299
}
300300

301301
// TODO: refactor model params in a common struct
302-
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
303-
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
302+
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
303+
common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token);
304+
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
304305

305306
if (params.escape) {
306307
string_process_escapes(params.prompt);
@@ -1629,6 +1630,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
16291630
params.hf_repo = value;
16301631
}
16311632
).set_env("LLAMA_ARG_HF_REPO"));
1633+
add_opt(common_arg( // registers the draft-model counterpart of --hf-repo
1634+
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
1635+
"Same as --hf-repo, but for the draft model (default: unused)",
1636+
[](common_params & params, const std::string & value) {
1637+
params.speculative.hf_repo = value; // stored separately from params.hf_repo (the main model)
1638+
}
1639+
).set_env("LLAMA_ARG_HFD_REPO")); // distinct env var: LLAMA_ARG_HF_REPO is already bound to --hf-repo, reusing it would make both options read the same variable
16321640
add_opt(common_arg(
16331641
{"-hff", "--hf-file"}, "FILE",
16341642
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",

common/common.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,11 @@ struct common_params_speculative {
175175
struct cpu_params cpuparams;
176176
struct cpu_params cpuparams_batch;
177177

178-
std::string model = ""; // draft model for speculative decoding // NOLINT
178+
std::string hf_repo = ""; // HF repo // NOLINT
179+
std::string hf_file = ""; // HF file // NOLINT
180+
181+
std::string model = ""; // draft model for speculative decoding // NOLINT
182+
std::string model_url = ""; // model url to download // NOLINT
179183
};
180184

181185
struct common_params_vocoder {
@@ -508,12 +512,14 @@ struct llama_model * common_load_model_from_url(
508512
const std::string & local_path,
509513
const std::string & hf_token,
510514
const struct llama_model_params & params);
515+
511516
struct llama_model * common_load_model_from_hf(
512517
const std::string & repo,
513518
const std::string & remote_path,
514519
const std::string & local_path,
515520
const std::string & hf_token,
516521
const struct llama_model_params & params);
522+
517523
std::pair<std::string, std::string> common_get_hf_file(
518524
const std::string & hf_repo_with_tag,
519525
const std::string & hf_token);

0 commit comments

Comments
 (0)