Commit bb5e924

Author: Piotr Stankiewicz
Allow passing GGUF splits via repeated --model args

If a segmented (split) GGUF model does not follow the expected split naming convention, it currently cannot be loaded. So, modify the argument parser to accept repeated --model arguments on the CLI; when more than one is given, treat the values as GGUF splits in the order they were specified.

Signed-off-by: Piotr Stankiewicz <[email protected]>
1 parent 5ba36f6 commit bb5e924

17 files changed (+187, -59 lines)

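The commit message above describes the new CLI behavior: when --model is given more than once, the values are collected in order and loaded as GGUF splits. As a minimal, illustrative sketch (not part of this commit), the repeated arguments ultimately map onto the existing llama_model_load_from_splits API roughly as follows; the split file names here are hypothetical placeholders.

// Illustrative only: load a model from explicitly ordered split files,
// mirroring what the repeated --model arguments drive in common.cpp.
#include <cstdio>
#include <vector>

#include "llama.h"

int main() {
    // splits as they would be given on the CLI, in order:
    //   --model my-model-part-a.gguf --model my-model-part-b.gguf
    std::vector<const char *> splits = {
        "my-model-part-a.gguf",
        "my-model-part-b.gguf",
    };

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_splits(splits.data(), splits.size(), mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model from splits\n");
        llama_backend_free();
        return 1;
    }

    // ... use the model ...

    llama_model_free(model);
    llama_backend_free();
    return 0;
}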

common/arg.cpp

Lines changed: 27 additions & 19 deletions
@@ -496,8 +496,12 @@ static bool common_download_model(
         LOG_ERR("%s: invalid model url\n", __func__);
         return false;
     }
+    if (model.paths.size() != 1) {
+        LOG_ERR("%s: model url can only be specified with one path\n", __func__);
+        return false;
+    }
 
-    if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
+    if (!common_download_file_single(model.url, model.paths[0], bearer_token, offline)) {
         return false;
     }
 
@@ -508,9 +512,9 @@ static bool common_download_model(
             /*.no_alloc = */ true,
             /*.ctx      = */ NULL,
         };
-        auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
+        auto * ctx_gguf = gguf_init_from_file(model.paths[0].c_str(), gguf_params);
         if (!ctx_gguf) {
-            LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
+            LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.paths[0].c_str());
             return false;
         }
 
@@ -529,8 +533,8 @@ static bool common_download_model(
        // Verify the first split file format
        // and extract split URL and PATH prefixes
        {
-           if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
-               LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
+           if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.paths[0].c_str(), 0, n_split)) {
+               LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.paths[0].c_str(), n_split);
                return false;
           }
 
@@ -548,7 +552,7 @@ static bool common_download_model(
            char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
            llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
 
-           if (std::string(split_path) == model.path) {
+           if (std::string(split_path) == model.paths[0]) {
                continue; // skip the already downloaded file
           }
 
@@ -798,7 +802,7 @@ static handle_model_result common_params_handle_model(
        if (!model.hf_repo.empty()) {
            // short-hand to avoid specifying --hf-file -> default it to --model
            if (model.hf_file.empty()) {
-               if (model.path.empty()) {
+               if (model.paths.empty()) {
                    auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
                    if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
                        exit(1); // built without CURL, error message already printed
@@ -811,30 +815,30 @@ static handle_model_result common_params_handle_model(
                    result.mmproj.hf_file = auto_detected.mmprojFile;
                }
            } else {
-               model.hf_file = model.path;
+               model.hf_file = model.paths[0];
           }
        }
 
            std::string model_endpoint = get_model_endpoint();
            model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
            // make sure model path is present (for caching purposes)
-           if (model.path.empty()) {
+           if (model.paths.empty()) {
                // this is to avoid different repo having same file name, or same file name in different subdirs
                std::string filename = model.hf_repo + "_" + model.hf_file;
                // to make sure we don't have any slashes in the filename
                string_replace_all(filename, "/", "_");
-               model.path = fs_get_cache_file(filename);
+               model.paths.push_back(fs_get_cache_file(filename));
           }
 
        } else if (!model.url.empty()) {
-           if (model.path.empty()) {
+           if (model.paths.empty()) {
                auto f = string_split<std::string>(model.url, '#').front();
                f = string_split<std::string>(f, '?').front();
-               model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
+               model.paths.push_back(fs_get_cache_file(string_split<std::string>(f, '/').back()));
           }
 
-       } else if (model.path.empty()) {
-           model.path = model_path_default;
+       } else if (model.paths.empty() && !model_path_default.empty()) {
+           model.paths.push_back(model_path_default);
       }
    }
 
@@ -986,7 +990,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
    auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline);
    if (params.no_mmproj) {
        params.mmproj = {};
-   } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
+   } else if (res.found_mmproj && params.mmproj.paths.empty() && params.mmproj.url.empty()) {
        // optionally, handle mmproj model when -hf is specified
        params.mmproj = res.mmproj;
   }
@@ -2285,7 +2289,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        "path to a multimodal projector file. see tools/mtmd/README.md\n"
        "note: if -hf is used, this argument can be omitted",
        [](common_params & params, const std::string & value) {
-           params.mmproj.path = value;
+           if (params.mmproj.paths.empty()) {
+               params.mmproj.paths.push_back(value);
+           } else {
+               params.mmproj.paths[0] = value;
+           }
       }
    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ"));
    add_opt(common_arg(
@@ -2597,7 +2605,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH
        ),
        [](common_params & params, const std::string & value) {
-           params.model.path = value;
+           params.model.paths.push_back(value);
       }
    ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
    add_opt(common_arg(
@@ -3330,7 +3338,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"-md", "--model-draft"}, "FNAME",
        "draft model for speculative decoding (default: unused)",
        [](common_params & params, const std::string & value) {
-           params.speculative.model.path = value;
+           params.speculative.model.paths.push_back(value);
       }
    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
    add_opt(common_arg(
@@ -3371,7 +3379,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"-mv", "--model-vocoder"}, "FNAME",
        "vocoder model for audio generation (default: unused)",
        [](common_params & params, const std::string & value) {
-           params.vocoder.model.path = value;
+           params.vocoder.model.paths.push_back(value);
       }
    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
    add_opt(common_arg(

common/common.cpp

Lines changed: 17 additions & 3 deletions
@@ -912,10 +912,24 @@ std::string fs_get_cache_file(const std::string & filename) {
 struct common_init_result common_init_from_params(common_params & params) {
     common_init_result iparams;
     auto mparams = common_model_params_to_llama(params);
+    llama_model * model = NULL;
+
+    if (params.model.paths.empty()) {
+        LOG_ERR("%s: failed to load model 'model path not specified'\n", __func__);
+        return iparams;
+    } else if (params.model.paths.size() == 1) {
+        model = llama_model_load_from_file(params.model.paths[0].c_str(), mparams);
+    } else {
+        std::vector<const char *> paths;
+        paths.reserve(params.model.paths.size());
+        for (const auto & path : params.model.paths) {
+            paths.push_back(path.c_str());
+        }
+        model = llama_model_load_from_splits(paths.data(), paths.size(), mparams);
+    }
 
-    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+        LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.paths[0].c_str());
         return iparams;
     }
@@ -925,7 +939,7 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     llama_context * lctx = llama_init_from_model(model, cparams);
     if (lctx == NULL) {
-        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.paths[0].c_str());
         llama_model_free(model);
         return iparams;
     }

common/common.h

Lines changed: 4 additions & 4 deletions
@@ -190,10 +190,10 @@ struct common_params_sampling {
 };
 
 struct common_params_model {
-    std::string path    = ""; // model local path // NOLINT
-    std::string url     = ""; // model url to download // NOLINT
-    std::string hf_repo = ""; // HF repo // NOLINT
-    std::string hf_file = ""; // HF file // NOLINT
+    std::vector<std::string> paths = {}; // model local path // NOLINT
+    std::string url                = ""; // model url to download // NOLINT
+    std::string hf_repo            = ""; // HF repo // NOLINT
+    std::string hf_file            = ""; // HF file // NOLINT
 };
 
 struct common_params_speculative {

examples/batched/batched.cpp

Lines changed: 14 additions & 1 deletion
@@ -41,7 +41,20 @@ int main(int argc, char ** argv) {
 
     llama_model_params model_params = common_model_params_to_llama(params);
 
-    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
+    llama_model * model = NULL;
+    if (params.model.paths.empty()) {
+        LOG_ERR("%s: failed to load model 'model path not specified'\n", __func__);
+        return 1;
+    } else if (params.model.paths.size() == 1) {
+        model = llama_model_load_from_file(params.model.paths[0].c_str(), model_params);
+    } else {
+        std::vector<const char *> paths;
+        paths.reserve(params.model.paths.size());
+        for (const auto & path : params.model.paths) {
+            paths.push_back(path.c_str());
+        }
+        model = llama_model_load_from_splits(paths.data(), paths.size(), model_params);
+    }
 
     if (model == NULL) {
         LOG_ERR("%s: error: unable to load model\n" , __func__);

examples/diffusion/diffusion-cli.cpp

Lines changed: 16 additions & 2 deletions
@@ -548,9 +548,23 @@ int main(int argc, char ** argv) {
     model_params.use_mlock     = params.use_mlock;
     model_params.check_tensors = params.check_tensors;
 
-    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
+    llama_model * model = NULL;
+    if (params.model.paths.empty()) {
+        LOG_ERR("error: failed to load model 'model path not specified'\n");
+        return 1;
+    } else if (params.model.paths.size() == 1) {
+        model = llama_model_load_from_file(params.model.paths[0].c_str(), model_params);
+    } else {
+        std::vector<const char *> paths;
+        paths.reserve(params.model.paths.size());
+        for (const auto & path : params.model.paths) {
+            paths.push_back(path.c_str());
+        }
+        model = llama_model_load_from_splits(paths.data(), paths.size(), model_params);
+    }
+
     if (!model) {
-        LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
+        LOG_ERR("error: failed to load model '%s'\n", params.model.paths[0].c_str());
         return 1;
     }

examples/gritlm/gritlm.cpp

Lines changed: 14 additions & 1 deletion
@@ -168,7 +168,20 @@ int main(int argc, char * argv[]) {
 
     llama_backend_init();
 
-    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
+    llama_model * model = NULL;
+    if (params.model.paths.empty()) {
+        fprintf(stderr, "failed to load model 'model path not specified'\n");
+        return 1;
+    } else if (params.model.paths.size() == 1) {
+        model = llama_model_load_from_file(params.model.paths[0].c_str(), mparams);
+    } else {
+        std::vector<const char *> paths;
+        paths.reserve(params.model.paths.size());
+        for (const auto & path : params.model.paths) {
+            paths.push_back(path.c_str());
+        }
+        model = llama_model_load_from_splits(paths.data(), paths.size(), mparams);
+    }
 
     // create generation context
     llama_context * ctx = llama_init_from_model(model, cparams);

examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion
@@ -495,7 +495,7 @@ int main(int argc, char ** argv) {
         params.prompt_file = "used built-in defaults";
     }
     LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
-    LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.path.c_str());
+    LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.paths[0].c_str());
 
     LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt) / (t_main_end - t_main_start) * 1e6);
     LOG_INF("Total gen tokens:    %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen) / (t_main_end - t_main_start) * 1e6);

examples/passkey/passkey.cpp

Lines changed: 14 additions & 1 deletion
@@ -64,7 +64,20 @@ int main(int argc, char ** argv) {
 
     llama_model_params model_params = common_model_params_to_llama(params);
 
-    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
+    llama_model * model;
+    if (params.model.paths.empty()) {
+        LOG_ERR("%s: failed to load model 'model path not specified'\n", __func__);
+        return 1;
+    } else if (params.model.paths.size() == 1) {
+        model = llama_model_load_from_file(params.model.paths[0].c_str(), model_params);
+    } else {
+        std::vector<const char *> paths;
+        paths.reserve(params.model.paths.size());
+        for (const auto & path : params.model.paths) {
+            paths.push_back(path.c_str());
+        }
+        model = llama_model_load_from_splits(paths.data(), paths.size(), model_params);
+    }
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n" , __func__);

examples/speculative-simple/speculative-simple.cpp

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.speculative.model.path.empty()) {
+    if (params.speculative.model.paths.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
@@ -67,7 +67,7 @@ int main(int argc, char ** argv) {
     ctx_dft = llama_init_dft.context.get();
 
     if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
-        LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
+        LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.paths[0].c_str(), params.model.paths[0].c_str());
     }
 
     // Tokenize the prompt

examples/speculative/speculative.cpp

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.speculative.model.path.empty()) {
+    if (params.speculative.model.paths.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
