|
7 | 7 | #include "log.h" |
8 | 8 | #include "sampling.h" |
9 | 9 |
|
| 10 | +#include <yaml-cpp/yaml.h> |
| 11 | + |
10 | 12 | // fix problem with std::min and std::max |
11 | 13 | #if defined(_WIN32) |
12 | 14 | #define WIN32_LEAN_AND_MEAN |
|
41 | 43 |
|
42 | 44 | using json = nlohmann::ordered_json; |
43 | 45 |
|
| 46 | +// YAML configuration parsing functions |
| 47 | +static void parse_yaml_sampling(const YAML::Node& node, common_params_sampling& sampling) { |
| 48 | + if (node["seed"]) sampling.seed = node["seed"].as<uint32_t>(); |
| 49 | + if (node["n_prev"]) sampling.n_prev = node["n_prev"].as<int32_t>(); |
| 50 | + if (node["n_probs"]) sampling.n_probs = node["n_probs"].as<int32_t>(); |
| 51 | + if (node["min_keep"]) sampling.min_keep = node["min_keep"].as<int32_t>(); |
| 52 | + if (node["top_k"]) sampling.top_k = node["top_k"].as<int32_t>(); |
| 53 | + if (node["top_p"]) sampling.top_p = node["top_p"].as<float>(); |
| 54 | + if (node["min_p"]) sampling.min_p = node["min_p"].as<float>(); |
| 55 | + if (node["xtc_probability"]) sampling.xtc_probability = node["xtc_probability"].as<float>(); |
| 56 | + if (node["xtc_threshold"]) sampling.xtc_threshold = node["xtc_threshold"].as<float>(); |
| 57 | + if (node["typ_p"]) sampling.typ_p = node["typ_p"].as<float>(); |
| 58 | + if (node["temp"]) sampling.temp = node["temp"].as<float>(); |
| 59 | + if (node["dynatemp_range"]) sampling.dynatemp_range = node["dynatemp_range"].as<float>(); |
| 60 | + if (node["dynatemp_exponent"]) sampling.dynatemp_exponent = node["dynatemp_exponent"].as<float>(); |
| 61 | + if (node["penalty_last_n"]) sampling.penalty_last_n = node["penalty_last_n"].as<int32_t>(); |
| 62 | + if (node["penalty_repeat"]) sampling.penalty_repeat = node["penalty_repeat"].as<float>(); |
| 63 | + if (node["penalty_freq"]) sampling.penalty_freq = node["penalty_freq"].as<float>(); |
| 64 | + if (node["penalty_present"]) sampling.penalty_present = node["penalty_present"].as<float>(); |
| 65 | + if (node["dry_multiplier"]) sampling.dry_multiplier = node["dry_multiplier"].as<float>(); |
| 66 | + if (node["dry_base"]) sampling.dry_base = node["dry_base"].as<float>(); |
| 67 | + if (node["dry_allowed_length"]) sampling.dry_allowed_length = node["dry_allowed_length"].as<int32_t>(); |
| 68 | + if (node["dry_penalty_last_n"]) sampling.dry_penalty_last_n = node["dry_penalty_last_n"].as<int32_t>(); |
| 69 | + if (node["mirostat"]) sampling.mirostat = node["mirostat"].as<int32_t>(); |
| 70 | + if (node["top_n_sigma"]) sampling.top_n_sigma = node["top_n_sigma"].as<float>(); |
| 71 | + if (node["mirostat_tau"]) sampling.mirostat_tau = node["mirostat_tau"].as<float>(); |
| 72 | + if (node["mirostat_eta"]) sampling.mirostat_eta = node["mirostat_eta"].as<float>(); |
| 73 | + if (node["ignore_eos"]) sampling.ignore_eos = node["ignore_eos"].as<bool>(); |
| 74 | + if (node["no_perf"]) sampling.no_perf = node["no_perf"].as<bool>(); |
| 75 | + if (node["timing_per_token"]) sampling.timing_per_token = node["timing_per_token"].as<bool>(); |
| 76 | + if (node["grammar"]) sampling.grammar = node["grammar"].as<std::string>(); |
| 77 | + if (node["grammar_lazy"]) sampling.grammar_lazy = node["grammar_lazy"].as<bool>(); |
| 78 | + |
| 79 | + if (node["dry_sequence_breakers"] && node["dry_sequence_breakers"].IsSequence()) { |
| 80 | + sampling.dry_sequence_breakers.clear(); |
| 81 | + for (const auto& breaker : node["dry_sequence_breakers"]) { |
| 82 | + sampling.dry_sequence_breakers.push_back(breaker.as<std::string>()); |
| 83 | + } |
| 84 | + } |
| 85 | +} |
| 86 | + |
| 87 | +static void parse_yaml_model(const YAML::Node& node, common_params_model& model) { |
| 88 | + if (node["path"]) model.path = node["path"].as<std::string>(); |
| 89 | + if (node["url"]) model.url = node["url"].as<std::string>(); |
| 90 | + if (node["hf_repo"]) model.hf_repo = node["hf_repo"].as<std::string>(); |
| 91 | + if (node["hf_file"]) model.hf_file = node["hf_file"].as<std::string>(); |
| 92 | +} |
| 93 | + |
| 94 | +static void parse_yaml_speculative(const YAML::Node& node, common_params_speculative& spec) { |
| 95 | + if (node["n_ctx"]) spec.n_ctx = node["n_ctx"].as<int32_t>(); |
| 96 | + if (node["n_max"]) spec.n_max = node["n_max"].as<int32_t>(); |
| 97 | + if (node["n_min"]) spec.n_min = node["n_min"].as<int32_t>(); |
| 98 | + if (node["n_gpu_layers"]) spec.n_gpu_layers = node["n_gpu_layers"].as<int32_t>(); |
| 99 | + if (node["p_split"]) spec.p_split = node["p_split"].as<float>(); |
| 100 | + if (node["p_min"]) spec.p_min = node["p_min"].as<float>(); |
| 101 | + if (node["cache_type_k"]) { |
| 102 | + std::string cache_type = node["cache_type_k"].as<std::string>(); |
| 103 | + if (cache_type == "f16") spec.cache_type_k = GGML_TYPE_F16; |
| 104 | + else if (cache_type == "f32") spec.cache_type_k = GGML_TYPE_F32; |
| 105 | + else if (cache_type == "q4_0") spec.cache_type_k = GGML_TYPE_Q4_0; |
| 106 | + else if (cache_type == "q4_1") spec.cache_type_k = GGML_TYPE_Q4_1; |
| 107 | + else if (cache_type == "q5_0") spec.cache_type_k = GGML_TYPE_Q5_0; |
| 108 | + else if (cache_type == "q5_1") spec.cache_type_k = GGML_TYPE_Q5_1; |
| 109 | + else if (cache_type == "q8_0") spec.cache_type_k = GGML_TYPE_Q8_0; |
| 110 | + } |
| 111 | + if (node["cache_type_v"]) { |
| 112 | + std::string cache_type = node["cache_type_v"].as<std::string>(); |
| 113 | + if (cache_type == "f16") spec.cache_type_v = GGML_TYPE_F16; |
| 114 | + else if (cache_type == "f32") spec.cache_type_v = GGML_TYPE_F32; |
| 115 | + else if (cache_type == "q4_0") spec.cache_type_v = GGML_TYPE_Q4_0; |
| 116 | + else if (cache_type == "q4_1") spec.cache_type_v = GGML_TYPE_Q4_1; |
| 117 | + else if (cache_type == "q5_0") spec.cache_type_v = GGML_TYPE_Q5_0; |
| 118 | + else if (cache_type == "q5_1") spec.cache_type_v = GGML_TYPE_Q5_1; |
| 119 | + else if (cache_type == "q8_0") spec.cache_type_v = GGML_TYPE_Q8_0; |
| 120 | + } |
| 121 | + if (node["model"]) { |
| 122 | + parse_yaml_model(node["model"], spec.model); |
| 123 | + } |
| 124 | +} |
| 125 | + |
| 126 | +static void parse_yaml_vocoder(const YAML::Node& node, common_params_vocoder& vocoder) { |
| 127 | + if (node["speaker_file"]) vocoder.speaker_file = node["speaker_file"].as<std::string>(); |
| 128 | + if (node["use_guide_tokens"]) vocoder.use_guide_tokens = node["use_guide_tokens"].as<bool>(); |
| 129 | + if (node["model"]) { |
| 130 | + parse_yaml_model(node["model"], vocoder.model); |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +static void parse_yaml_diffusion(const YAML::Node& node, common_params_diffusion& diffusion) { |
| 135 | + if (node["steps"]) diffusion.steps = node["steps"].as<int32_t>(); |
| 136 | + if (node["visual_mode"]) diffusion.visual_mode = node["visual_mode"].as<bool>(); |
| 137 | + if (node["eps"]) diffusion.eps = node["eps"].as<float>(); |
| 138 | + if (node["block_length"]) diffusion.block_length = node["block_length"].as<int32_t>(); |
| 139 | + if (node["algorithm"]) diffusion.algorithm = node["algorithm"].as<int32_t>(); |
| 140 | + if (node["alg_temp"]) diffusion.alg_temp = node["alg_temp"].as<float>(); |
| 141 | + if (node["cfg_scale"]) diffusion.cfg_scale = node["cfg_scale"].as<float>(); |
| 142 | + if (node["add_gumbel_noise"]) diffusion.add_gumbel_noise = node["add_gumbel_noise"].as<bool>(); |
| 143 | +} |
| 144 | + |
| 145 | +static bool load_yaml_config(const std::string& config_path, common_params& params) { |
| 146 | + try { |
| 147 | + YAML::Node config = YAML::LoadFile(config_path); |
| 148 | + |
| 149 | + // Parse main parameters |
| 150 | + if (config["n_predict"]) params.n_predict = config["n_predict"].as<int32_t>(); |
| 151 | + if (config["n_ctx"]) params.n_ctx = config["n_ctx"].as<int32_t>(); |
| 152 | + if (config["n_batch"]) params.n_batch = config["n_batch"].as<int32_t>(); |
| 153 | + if (config["n_ubatch"]) params.n_ubatch = config["n_ubatch"].as<int32_t>(); |
| 154 | + if (config["n_keep"]) params.n_keep = config["n_keep"].as<int32_t>(); |
| 155 | + if (config["n_chunks"]) params.n_chunks = config["n_chunks"].as<int32_t>(); |
| 156 | + if (config["n_parallel"]) params.n_parallel = config["n_parallel"].as<int32_t>(); |
| 157 | + if (config["n_sequences"]) params.n_sequences = config["n_sequences"].as<int32_t>(); |
| 158 | + if (config["grp_attn_n"]) params.grp_attn_n = config["grp_attn_n"].as<int32_t>(); |
| 159 | + if (config["grp_attn_w"]) params.grp_attn_w = config["grp_attn_w"].as<int32_t>(); |
| 160 | + if (config["n_print"]) params.n_print = config["n_print"].as<int32_t>(); |
| 161 | + if (config["rope_freq_base"]) params.rope_freq_base = config["rope_freq_base"].as<float>(); |
| 162 | + if (config["rope_freq_scale"]) params.rope_freq_scale = config["rope_freq_scale"].as<float>(); |
| 163 | + if (config["yarn_ext_factor"]) params.yarn_ext_factor = config["yarn_ext_factor"].as<float>(); |
| 164 | + if (config["yarn_attn_factor"]) params.yarn_attn_factor = config["yarn_attn_factor"].as<float>(); |
| 165 | + if (config["yarn_beta_fast"]) params.yarn_beta_fast = config["yarn_beta_fast"].as<float>(); |
| 166 | + if (config["yarn_beta_slow"]) params.yarn_beta_slow = config["yarn_beta_slow"].as<float>(); |
| 167 | + if (config["yarn_orig_ctx"]) params.yarn_orig_ctx = config["yarn_orig_ctx"].as<int32_t>(); |
| 168 | + if (config["n_gpu_layers"]) params.n_gpu_layers = config["n_gpu_layers"].as<int32_t>(); |
| 169 | + if (config["main_gpu"]) params.main_gpu = config["main_gpu"].as<int32_t>(); |
| 170 | + |
| 171 | + // Parse string parameters |
| 172 | + if (config["model_alias"]) params.model_alias = config["model_alias"].as<std::string>(); |
| 173 | + if (config["hf_token"]) params.hf_token = config["hf_token"].as<std::string>(); |
| 174 | + if (config["prompt"]) params.prompt = config["prompt"].as<std::string>(); |
| 175 | + if (config["system_prompt"]) params.system_prompt = config["system_prompt"].as<std::string>(); |
| 176 | + if (config["prompt_file"]) params.prompt_file = config["prompt_file"].as<std::string>(); |
| 177 | + if (config["path_prompt_cache"]) params.path_prompt_cache = config["path_prompt_cache"].as<std::string>(); |
| 178 | + if (config["input_prefix"]) params.input_prefix = config["input_prefix"].as<std::string>(); |
| 179 | + if (config["input_suffix"]) params.input_suffix = config["input_suffix"].as<std::string>(); |
| 180 | + if (config["lookup_cache_static"]) params.lookup_cache_static = config["lookup_cache_static"].as<std::string>(); |
| 181 | + if (config["lookup_cache_dynamic"]) params.lookup_cache_dynamic = config["lookup_cache_dynamic"].as<std::string>(); |
| 182 | + if (config["logits_file"]) params.logits_file = config["logits_file"].as<std::string>(); |
| 183 | + |
| 184 | + // Parse boolean parameters |
| 185 | + if (config["lora_init_without_apply"]) params.lora_init_without_apply = config["lora_init_without_apply"].as<bool>(); |
| 186 | + if (config["offline"]) params.offline = config["offline"].as<bool>(); |
| 187 | + |
| 188 | + // Parse integer parameters |
| 189 | + if (config["verbosity"]) params.verbosity = config["verbosity"].as<int32_t>(); |
| 190 | + if (config["control_vector_layer_start"]) params.control_vector_layer_start = config["control_vector_layer_start"].as<int32_t>(); |
| 191 | + if (config["control_vector_layer_end"]) params.control_vector_layer_end = config["control_vector_layer_end"].as<int32_t>(); |
| 192 | + if (config["ppl_stride"]) params.ppl_stride = config["ppl_stride"].as<int32_t>(); |
| 193 | + if (config["ppl_output_type"]) params.ppl_output_type = config["ppl_output_type"].as<int32_t>(); |
| 194 | + |
| 195 | + // Parse array parameters |
| 196 | + if (config["in_files"] && config["in_files"].IsSequence()) { |
| 197 | + params.in_files.clear(); |
| 198 | + for (const auto& file : config["in_files"]) { |
| 199 | + params.in_files.push_back(file.as<std::string>()); |
| 200 | + } |
| 201 | + } |
| 202 | + |
| 203 | + if (config["antiprompt"] && config["antiprompt"].IsSequence()) { |
| 204 | + params.antiprompt.clear(); |
| 205 | + for (const auto& prompt : config["antiprompt"]) { |
| 206 | + params.antiprompt.push_back(prompt.as<std::string>()); |
| 207 | + } |
| 208 | + } |
| 209 | + |
| 210 | + if (config["sampling"]) { |
| 211 | + parse_yaml_sampling(config["sampling"], params.sampling); |
| 212 | + } |
| 213 | + |
| 214 | + if (config["model"]) { |
| 215 | + parse_yaml_model(config["model"], params.model); |
| 216 | + } |
| 217 | + |
| 218 | + if (config["speculative"]) { |
| 219 | + parse_yaml_speculative(config["speculative"], params.speculative); |
| 220 | + } |
| 221 | + |
| 222 | + if (config["vocoder"]) { |
| 223 | + parse_yaml_vocoder(config["vocoder"], params.vocoder); |
| 224 | + } |
| 225 | + |
| 226 | + if (config["diffusion"]) { |
| 227 | + parse_yaml_diffusion(config["diffusion"], params.diffusion); |
| 228 | + } |
| 229 | + |
| 230 | + return true; |
| 231 | + } catch (const YAML::Exception& e) { |
| 232 | + fprintf(stderr, "YAML parsing error: %s\n", e.what()); |
| 233 | + return false; |
| 234 | + } catch (const std::exception& e) { |
| 235 | + fprintf(stderr, "Error loading YAML config: %s\n", e.what()); |
| 236 | + return false; |
| 237 | + } |
| 238 | +} |
| 239 | + |
44 | 240 | std::initializer_list<enum llama_example> mmproj_examples = { |
45 | 241 | LLAMA_EXAMPLE_MTMD, |
46 | 242 | LLAMA_EXAMPLE_SERVER, |
@@ -1223,6 +1419,17 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e |
1223 | 1419 | const common_params params_org = ctx_arg.params; // the example can modify the default params |
1224 | 1420 |
|
1225 | 1421 | try { |
| 1422 | + for (int i = 1; i < argc; i++) { |
| 1423 | + if (strcmp(argv[i], "--config") == 0 && i + 1 < argc) { |
| 1424 | + if (!load_yaml_config(argv[i + 1], ctx_arg.params)) { |
| 1425 | + fprintf(stderr, "Failed to load YAML config: %s\n", argv[i + 1]); |
| 1426 | + ctx_arg.params = params_org; |
| 1427 | + return false; |
| 1428 | + } |
| 1429 | + break; |
| 1430 | + } |
| 1431 | + } |
| 1432 | + |
1226 | 1433 | if (!common_params_parse_ex(argc, argv, ctx_arg)) { |
1227 | 1434 | ctx_arg.params = params_org; |
1228 | 1435 | return false; |
@@ -1294,6 +1501,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex |
1294 | 1501 | }; |
1295 | 1502 |
|
1296 | 1503 |
|
| 1504 | + add_opt(common_arg( |
| 1505 | + {"--config"}, |
| 1506 | + "FNAME", |
| 1507 | + "path to YAML configuration file", |
| 1508 | + [](common_params & params, const std::string & value) { |
| 1509 | + params.config_file = value; |
| 1510 | + } |
| 1511 | + )); |
| 1512 | + |
1297 | 1513 | add_opt(common_arg( |
1298 | 1514 | {"-h", "--help", "--usage"}, |
1299 | 1515 | "print usage and exit", |
|
0 commit comments