Commit ee629d5

arg : add model catalog
1 parent 8c83449 commit ee629d5

File tree

4 files changed (+286, -170 lines)


common/arg.cpp

Lines changed: 98 additions & 155 deletions
@@ -5,6 +5,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "chat.h"
+#include "catalog.h"
 
 // fix problem with std::min and std::max
 #if defined(_WIN32)
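
The new catalog.h header lives in one of this commit's other changed files and is not shown in this excerpt. Judging from how model_catalog is consumed in the parser further down (entry.name, entry.examples, entry.handler), a minimal shape for it might look like the sketch below; the type and field names are assumptions, not the actual header:

// catalog.h -- hypothetical sketch, inferred from usage in common/arg.cpp below
#pragma once

#include "common.h" // for common_params and enum llama_example

#include <functional>
#include <vector>

struct common_catalog_entry {
    const char *                         name;     // matched against the positional argument
    std::vector<enum llama_example>      examples; // tools (examples) this preset supports
    std::function<void(common_params &)> handler;  // fills hf_repo/hf_file and other defaults
};

// presumably defined in a catalog.cpp; iterated by common_params_parse_ex()
extern std::vector<common_catalog_entry> model_catalog;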
@@ -608,15 +609,18 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
  *
  * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
  */
-static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
+static struct common_hf_file_res common_get_hf_file(
+        const std::string & hf_repo_with_tag,
+        const std::string & bearer_token,
+        const std::string & model_endpoint) {
     auto parts = string_split<std::string>(hf_repo_with_tag, ':');
     std::string tag = parts.size() > 1 ? parts.back() : "latest";
     std::string hf_repo = parts[0];
     if (string_split<std::string>(hf_repo, '/').size() != 2) {
         throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
     }
 
-    std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
+    std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
 
     // headers
     std::vector<std::string> headers;
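
For orientation, this is the manifest URL the changed line above now builds; the endpoint is passed in by the caller instead of being looked up via get_model_endpoint(). The endpoint value and repo below are assumptions for illustration:

// illustration only: hypothetical repo, assumed default-looking endpoint
#include <string>

std::string example_manifest_url() {
    std::string model_endpoint = "https://huggingface.co/";     // assumed default
    std::string hf_repo        = "ggml-org/gemma-3-4b-it-GGUF"; // hypothetical
    std::string tag            = "latest";
    return model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
    // -> "https://huggingface.co/v2/ggml-org/gemma-3-4b-it-GGUF/manifests/latest"
}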
@@ -715,7 +719,7 @@ static bool common_download_model(
     return false;
 }
 
-static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) {
+static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &, const std::string &) {
     LOG_ERR("error: built without CURL, cannot download model from the internet\n");
     return {};
 }
@@ -742,15 +746,15 @@ struct handle_model_result {
 static handle_model_result common_params_handle_model(
         struct common_params_model & model,
         const std::string & bearer_token,
-        const std::string & model_path_default) {
+        const std::string & model_endpoint) {
     handle_model_result result;
     // handle pre-fill default model path and url based on hf_repo and hf_file
     {
         if (!model.hf_repo.empty()) {
             // short-hand to avoid specifying --hf-file -> default it to --model
             if (model.hf_file.empty()) {
                 if (model.path.empty()) {
-                    auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token);
+                    auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, model_endpoint);
                     if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
                         exit(1); // built without CURL, error message already printed
                     }
@@ -766,7 +770,6 @@ static handle_model_result common_params_handle_model(
             }
         }
 
-        std::string model_endpoint = get_model_endpoint();
         model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
         // make sure model path is present (for caching purposes)
         if (model.path.empty()) {
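
Likewise, the direct-download URL now reuses the caller-supplied endpoint. A sketch of what the line above produces, with the endpoint again assumed and the repo/file values borrowed from one of the presets removed at the bottom of this diff:

// illustration only: values are assumptions for the example
#include <string>

std::string example_resolve_url() {
    std::string model_endpoint = "https://huggingface.co/"; // assumed default
    std::string hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
    std::string hf_file = "bge-small-en-v1.5-q8_0.gguf";
    return model_endpoint + hf_repo + "/resolve/main/" + hf_file;
    // -> "https://huggingface.co/ggml-org/bge-small-en-v1.5-Q8_0-GGUF/resolve/main/bge-small-en-v1.5-q8_0.gguf"
}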
@@ -784,8 +787,6 @@ static handle_model_result common_params_handle_model(
                 model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
             }
 
-        } else if (model.path.empty()) {
-            model.path = model_path_default;
         }
     }

@@ -835,7 +836,6 @@ static std::string get_all_kv_cache_types() {
 //
 
 static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
-    std::string arg;
     const std::string arg_prefix = "--";
     common_params & params = ctx_arg.params;

@@ -875,16 +875,91 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         }
     };
 
+    // normalize args
+    std::string input_pos_arg;
+    std::vector<std::string> input_opt_args;
+    input_opt_args.reserve(argc - 1);
     for (int i = 1; i < argc; i++) {
         const std::string arg_prefix = "--";
 
         std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
         if (arg_to_options.find(arg) == arg_to_options.end()) {
+            // no match: this may be the (single) positional argument
+            if (input_pos_arg.empty()) {
+                input_pos_arg = std::move(arg);
+                continue;
+            }
+            // a positional argument was already given; we cannot have another one
             throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
         }
+        input_opt_args.emplace_back(arg);
+    }
+
+    // handle the positional argument (we only support one)
+    // the logic is as follows:
+    // 1. try to find the name in the model catalog
+    // 2. if not found, check for a protocol:// prefix
+    // 3. if no protocol is found, assume it is a local file
+    {
+        bool is_handled = input_pos_arg.empty(); // nothing to resolve if no positional argument was given
+        // check the catalog
+        for (auto & entry : model_catalog) {
+            if (input_pos_arg == entry.name) {
+                is_handled = true;
+                // check if the model supports the current example
+                bool is_supported = false;
+                for (auto & ex : entry.examples) {
+                    if (ctx_arg.ex == ex) {
+                        is_supported = true;
+                        break;
+                    }
+                }
+                if (is_supported) {
+                    entry.handler(params);
+                } else {
+                    LOG_ERR("error: model '%s' is not supported by this tool\n", entry.name);
+                    exit(1);
+                }
+                break;
+            }
+        }
+        // check the protocol
+        // for contributors: if you want to add a new protocol, please make sure
+        // it supports either /resolve/main or the registry API
+        // see common_params_handle_model() to understand how this is handled
+        // note: we don't support ollama because it usually hosts proprietary models (incompatible with llama.cpp)
+        if (!is_handled) {
+            const std::string & arg = input_pos_arg;
+            // check if it is a URL
+            if (string_starts_with(arg, "http://") || string_starts_with(arg, "https://")) {
+                params.model.url = arg;
+            } else if (string_starts_with(arg, "hf://")) {
+                // hugging face repo
+                params.model.hf_repo = arg.substr(5);
+            } else if (string_starts_with(arg, "hf-mirror://")) {
+                // hugging face mirror
+                params.custom_model_endpoint = "hf-mirror.com";
+                params.model.hf_repo = arg.substr(12);
+            } else if (string_starts_with(arg, "ms://")) {
+                // modelscope
+                params.custom_model_endpoint = "modelscope.cn";
+                params.model.hf_repo = arg.substr(5);
+            } else {
+                // assume it is a local file
+                params.model.path = arg;
+            }
+        }
+    }
+
+    // handle optional args
+    // note: this loop starts at 0 -- unlike the argv loop above, input_opt_args does not contain the program name
+    for (size_t i = 0; i < input_opt_args.size(); i++) {
+        const std::string & arg = input_opt_args[i];
         auto opt = *arg_to_options[arg];
         if (opt.has_value_from_env()) {
             fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
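
Condensed, the precedence for the positional argument introduced above is: catalog name, then explicit http(s) URL, then the hf://, hf-mirror:// and ms:// prefixes, then a local file path. The substr offsets in the real code match the prefix lengths ("hf://" and "ms://" are 5 characters, "hf-mirror://" is 12). A standalone distillation of that dispatch, with hypothetical names:

// sketch of the dispatch precedence; enum and function names are made up
#include <string>

enum class model_source { catalog, url, hf, hf_mirror, ms, local_file };

static model_source classify_positional(const std::string & arg, bool in_catalog) {
    auto starts_with = [&](const char * p) { return arg.rfind(p, 0) == 0; };
    if (in_catalog)                                        return model_source::catalog;
    if (starts_with("http://") || starts_with("https://")) return model_source::url;
    if (starts_with("hf://"))                              return model_source::hf;
    if (starts_with("hf-mirror://"))                       return model_source::hf_mirror;
    if (starts_with("ms://"))                              return model_source::ms;
    return model_source::local_file; // fallback: treat the argument as a file path
}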
@@ -934,7 +1009,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
 
     // handle model and download
     {
-        auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH);
+        std::string model_endpoint = params.get_model_endpoint();
+        auto res = common_params_handle_model(params.model, params.hf_token, model_endpoint);
         if (params.no_mmproj) {
             params.mmproj = {};
         } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -944,12 +1020,12 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         // only download mmproj if the current example is using it
         for (auto & ex : mmproj_examples) {
             if (ctx_arg.ex == ex) {
-                common_params_handle_model(params.mmproj, params.hf_token, "");
+                common_params_handle_model(params.mmproj, params.hf_token, model_endpoint);
                 break;
             }
         }
-        common_params_handle_model(params.speculative.model, params.hf_token, "");
-        common_params_handle_model(params.vocoder.model, params.hf_token, "");
+        common_params_handle_model(params.speculative.model, params.hf_token, model_endpoint);
+        common_params_handle_model(params.vocoder.model, params.hf_token, model_endpoint);
     }
 
     if (params.escape) {
@@ -985,6 +1061,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         ));
     }
 
+    if (params.model.path.empty()) {
+        throw std::invalid_argument(
+            "model path is empty\n"
+            "please specify a model file or use one from the catalog\n"
+            "use --catalog to see the list of available models\n");
+    }
+
     return true;
 }

@@ -3178,145 +3261,5 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_TTS}));
 
-    // model-specific
-    add_opt(common_arg(
-        {"--tts-oute-default"},
-        string_format("use default OuteTTS models (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF";
-            params.model.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf";
-            params.vocoder.model.hf_repo = "ggml-org/WavTokenizer";
-            params.vocoder.model.hf_file = "WavTokenizer-Large-75-F16.gguf";
-        }
-    ).set_examples({LLAMA_EXAMPLE_TTS}));
-
-    add_opt(common_arg(
-        {"--embd-bge-small-en-default"},
-        string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
-            params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
-            params.pooling_type = LLAMA_POOLING_TYPE_NONE;
-            params.embd_normalize = 2;
-            params.n_ctx = 512;
-            params.verbose_prompt = true;
-            params.embedding = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
-
-    add_opt(common_arg(
-        {"--embd-e5-small-en-default"},
-        string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
-            params.model.hf_file = "e5-small-v2-q8_0.gguf";
-            params.pooling_type = LLAMA_POOLING_TYPE_NONE;
-            params.embd_normalize = 2;
-            params.n_ctx = 512;
-            params.verbose_prompt = true;
-            params.embedding = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
-
-    add_opt(common_arg(
-        {"--embd-gte-small-default"},
-        string_format("use default gte-small model (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
-            params.model.hf_file = "gte-small-q8_0.gguf";
-            params.pooling_type = LLAMA_POOLING_TYPE_NONE;
-            params.embd_normalize = 2;
-            params.n_ctx = 512;
-            params.verbose_prompt = true;
-            params.embedding = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
-
-    add_opt(common_arg(
-        {"--fim-qwen-1.5b-default"},
-        string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
-            params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
-            params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
-            params.n_ubatch = 1024;
-            params.n_batch = 1024;
-            params.n_ctx = 0;
-            params.n_cache_reuse = 256;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
-
-    add_opt(common_arg(
-        {"--fim-qwen-3b-default"},
-        string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
-            params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
-            params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
-            params.n_ubatch = 1024;
-            params.n_batch = 1024;
-            params.n_ctx = 0;
-            params.n_cache_reuse = 256;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
-
-    add_opt(common_arg(
-        {"--fim-qwen-7b-default"},
-        string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
-            params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
-            params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
-            params.n_ubatch = 1024;
-            params.n_batch = 1024;
-            params.n_ctx = 0;
-            params.n_cache_reuse = 256;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
-
-    add_opt(common_arg(
-        {"--fim-qwen-7b-spec"},
-        string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
-            params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
-            params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-            params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
-            params.speculative.n_gpu_layers = 99;
-            params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
-            params.n_ubatch = 1024;
-            params.n_batch = 1024;
-            params.n_ctx = 0;
-            params.n_cache_reuse = 256;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
-
-    add_opt(common_arg(
-        {"--fim-qwen-14b-spec"},
-        string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
-        [](common_params & params) {
-            params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
-            params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
-            params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-            params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
-            params.speculative.n_gpu_layers = 99;
-            params.port = 8012;
-            params.n_gpu_layers = 99;
-            params.flash_attn = true;
-            params.n_ubatch = 1024;
-            params.n_batch = 1024;
-            params.n_ctx = 0;
-            params.n_cache_reuse = 256;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
-
     return ctx_arg;
 }
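
The presets deleted above are presumably what the new catalog replaces, per the commit title. Reusing the hypothetical common_catalog_entry shape sketched near the top of this page, one migrated entry could look as follows; the catalog key is an assumption, while the parameter values are copied verbatim from the removed --embd-bge-small-en-default option:

// hypothetical catalog entry; layout assumed, values from the deleted option
static const common_catalog_entry bge_small_en_entry = {
    /* name     */ "bge-small-en", // assumed catalog key
    /* examples */ { LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER },
    /* handler  */ [](common_params & params) {
        params.model.hf_repo  = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
        params.model.hf_file  = "bge-small-en-v1.5-q8_0.gguf";
        params.pooling_type   = LLAMA_POOLING_TYPE_NONE;
        params.embd_normalize = 2;
        params.n_ctx          = 512;
        params.verbose_prompt = true;
        params.embedding      = true;
    },
};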
