|
18 | 18 |
|
19 | 19 | #define JSON_ASSERT GGML_ASSERT |
20 | 20 | #include <nlohmann/json.hpp> |
| 21 | +#include <yaml-cpp/yaml.h> |
21 | 22 |
|
22 | 23 | #include <algorithm> |
23 | 24 | #include <climits> |
@@ -65,6 +66,169 @@ static void write_file(const std::string & fname, const std::string & content) { |
65 | 66 | file.close(); |
66 | 67 | } |
67 | 68 |
|
| 69 | +bool common_params_load_from_yaml(const std::string & config_file, common_params & params) { |
| 70 | + try { |
| 71 | + YAML::Node config = YAML::LoadFile(config_file); |
| 72 | + |
| 73 | + // Model parameters |
| 74 | + if (config["model"]) { |
| 75 | + if (config["model"]["path"]) { |
| 76 | + params.model.path = config["model"]["path"].as<std::string>(); |
| 77 | + } |
| 78 | + if (config["model"]["url"]) { |
| 79 | + params.model.url = config["model"]["url"].as<std::string>(); |
| 80 | + } |
| 81 | + if (config["model"]["hf_repo"]) { |
| 82 | + params.model.hf_repo = config["model"]["hf_repo"].as<std::string>(); |
| 83 | + } |
| 84 | + if (config["model"]["hf_file"]) { |
| 85 | + params.model.hf_file = config["model"]["hf_file"].as<std::string>(); |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + // Basic parameters |
| 90 | + if (config["n_predict"]) params.n_predict = config["n_predict"].as<int32_t>(); |
| 91 | + if (config["n_ctx"]) params.n_ctx = config["n_ctx"].as<int32_t>(); |
| 92 | + if (config["n_batch"]) params.n_batch = config["n_batch"].as<int32_t>(); |
| 93 | + if (config["n_ubatch"]) params.n_ubatch = config["n_ubatch"].as<int32_t>(); |
| 94 | + if (config["n_keep"]) params.n_keep = config["n_keep"].as<int32_t>(); |
| 95 | + if (config["n_chunks"]) params.n_chunks = config["n_chunks"].as<int32_t>(); |
| 96 | + if (config["n_parallel"]) params.n_parallel = config["n_parallel"].as<int32_t>(); |
| 97 | + if (config["n_sequences"]) params.n_sequences = config["n_sequences"].as<int32_t>(); |
| 98 | + if (config["n_gpu_layers"]) params.n_gpu_layers = config["n_gpu_layers"].as<int32_t>(); |
| 99 | + if (config["main_gpu"]) params.main_gpu = config["main_gpu"].as<int32_t>(); |
| 100 | + if (config["verbosity"]) params.verbosity = config["verbosity"].as<int32_t>(); |
| 101 | + |
| 102 | + // String parameters |
| 103 | + if (config["prompt"]) params.prompt = config["prompt"].as<std::string>(); |
| 104 | + if (config["system_prompt"]) params.system_prompt = config["system_prompt"].as<std::string>(); |
| 105 | + if (config["prompt_file"]) params.prompt_file = config["prompt_file"].as<std::string>(); |
| 106 | + if (config["input_prefix"]) params.input_prefix = config["input_prefix"].as<std::string>(); |
| 107 | + if (config["input_suffix"]) params.input_suffix = config["input_suffix"].as<std::string>(); |
| 108 | + if (config["hf_token"]) params.hf_token = config["hf_token"].as<std::string>(); |
| 109 | + |
| 110 | + // Float parameters |
| 111 | + if (config["rope_freq_base"]) params.rope_freq_base = config["rope_freq_base"].as<float>(); |
| 112 | + if (config["rope_freq_scale"]) params.rope_freq_scale = config["rope_freq_scale"].as<float>(); |
| 113 | + if (config["yarn_ext_factor"]) params.yarn_ext_factor = config["yarn_ext_factor"].as<float>(); |
| 114 | + if (config["yarn_attn_factor"]) params.yarn_attn_factor = config["yarn_attn_factor"].as<float>(); |
| 115 | + if (config["yarn_beta_fast"]) params.yarn_beta_fast = config["yarn_beta_fast"].as<float>(); |
| 116 | + if (config["yarn_beta_slow"]) params.yarn_beta_slow = config["yarn_beta_slow"].as<float>(); |
| 117 | + |
| 118 | + // Boolean parameters |
| 119 | + if (config["interactive"]) params.interactive = config["interactive"].as<bool>(); |
| 120 | + if (config["interactive_first"]) params.interactive_first = config["interactive_first"].as<bool>(); |
| 121 | + if (config["conversation"]) { |
| 122 | + params.conversation_mode = config["conversation"].as<bool>() ? |
| 123 | + COMMON_CONVERSATION_MODE_ENABLED : COMMON_CONVERSATION_MODE_DISABLED; |
| 124 | + } |
| 125 | + if (config["use_color"]) params.use_color = config["use_color"].as<bool>(); |
| 126 | + if (config["simple_io"]) params.simple_io = config["simple_io"].as<bool>(); |
| 127 | + if (config["embedding"]) params.embedding = config["embedding"].as<bool>(); |
| 128 | + if (config["escape"]) params.escape = config["escape"].as<bool>(); |
| 129 | + if (config["multiline_input"]) params.multiline_input = config["multiline_input"].as<bool>(); |
| 130 | + if (config["cont_batching"]) params.cont_batching = config["cont_batching"].as<bool>(); |
| 131 | + if (config["flash_attn"]) { |
| 132 | + params.flash_attn_type = config["flash_attn"].as<bool>() ? |
| 133 | + LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED; |
| 134 | + } |
| 135 | + if (config["no_perf"]) params.no_perf = config["no_perf"].as<bool>(); |
| 136 | + if (config["ctx_shift"]) params.ctx_shift = config["ctx_shift"].as<bool>(); |
| 137 | + if (config["input_prefix_bos"]) params.input_prefix_bos = config["input_prefix_bos"].as<bool>(); |
| 138 | + if (config["use_mmap"]) params.use_mmap = config["use_mmap"].as<bool>(); |
| 139 | + if (config["use_mlock"]) params.use_mlock = config["use_mlock"].as<bool>(); |
| 140 | + if (config["verbose_prompt"]) params.verbose_prompt = config["verbose_prompt"].as<bool>(); |
| 141 | + if (config["display_prompt"]) params.display_prompt = config["display_prompt"].as<bool>(); |
| 142 | + if (config["no_kv_offload"]) params.no_kv_offload = config["no_kv_offload"].as<bool>(); |
| 143 | + if (config["warmup"]) params.warmup = config["warmup"].as<bool>(); |
| 144 | + if (config["check_tensors"]) params.check_tensors = config["check_tensors"].as<bool>(); |
| 145 | + |
| 146 | + // CPU parameters |
| 147 | + if (config["cpuparams"]) { |
| 148 | + const auto & cpu_config = config["cpuparams"]; |
| 149 | + if (cpu_config["n_threads"]) params.cpuparams.n_threads = cpu_config["n_threads"].as<int>(); |
| 150 | + if (cpu_config["strict_cpu"]) params.cpuparams.strict_cpu = cpu_config["strict_cpu"].as<bool>(); |
| 151 | + if (cpu_config["poll"]) params.cpuparams.poll = cpu_config["poll"].as<uint32_t>(); |
| 152 | + } |
| 153 | + |
| 154 | + // Sampling parameters |
| 155 | + if (config["sampling"]) { |
| 156 | + const auto & sampling_config = config["sampling"]; |
| 157 | + if (sampling_config["seed"]) params.sampling.seed = sampling_config["seed"].as<uint32_t>(); |
| 158 | + if (sampling_config["n_prev"]) params.sampling.n_prev = sampling_config["n_prev"].as<int32_t>(); |
| 159 | + if (sampling_config["n_probs"]) params.sampling.n_probs = sampling_config["n_probs"].as<int32_t>(); |
| 160 | + if (sampling_config["min_keep"]) params.sampling.min_keep = sampling_config["min_keep"].as<int32_t>(); |
| 161 | + if (sampling_config["top_k"]) params.sampling.top_k = sampling_config["top_k"].as<int32_t>(); |
| 162 | + if (sampling_config["top_p"]) params.sampling.top_p = sampling_config["top_p"].as<float>(); |
| 163 | + if (sampling_config["min_p"]) params.sampling.min_p = sampling_config["min_p"].as<float>(); |
| 164 | + if (sampling_config["xtc_probability"]) params.sampling.xtc_probability = sampling_config["xtc_probability"].as<float>(); |
| 165 | + if (sampling_config["xtc_threshold"]) params.sampling.xtc_threshold = sampling_config["xtc_threshold"].as<float>(); |
| 166 | + if (sampling_config["typ_p"]) params.sampling.typ_p = sampling_config["typ_p"].as<float>(); |
| 167 | + if (sampling_config["temp"]) params.sampling.temp = sampling_config["temp"].as<float>(); |
| 168 | + if (sampling_config["dynatemp_range"]) params.sampling.dynatemp_range = sampling_config["dynatemp_range"].as<float>(); |
| 169 | + if (sampling_config["dynatemp_exponent"]) params.sampling.dynatemp_exponent = sampling_config["dynatemp_exponent"].as<float>(); |
| 170 | + if (sampling_config["penalty_last_n"]) params.sampling.penalty_last_n = sampling_config["penalty_last_n"].as<int32_t>(); |
| 171 | + if (sampling_config["penalty_repeat"]) params.sampling.penalty_repeat = sampling_config["penalty_repeat"].as<float>(); |
| 172 | + if (sampling_config["penalty_freq"]) params.sampling.penalty_freq = sampling_config["penalty_freq"].as<float>(); |
| 173 | + if (sampling_config["penalty_present"]) params.sampling.penalty_present = sampling_config["penalty_present"].as<float>(); |
| 174 | + if (sampling_config["dry_multiplier"]) params.sampling.dry_multiplier = sampling_config["dry_multiplier"].as<float>(); |
| 175 | + if (sampling_config["dry_base"]) params.sampling.dry_base = sampling_config["dry_base"].as<float>(); |
| 176 | + if (sampling_config["dry_allowed_length"]) params.sampling.dry_allowed_length = sampling_config["dry_allowed_length"].as<int32_t>(); |
| 177 | + if (sampling_config["dry_penalty_last_n"]) params.sampling.dry_penalty_last_n = sampling_config["dry_penalty_last_n"].as<int32_t>(); |
| 178 | + if (sampling_config["mirostat"]) params.sampling.mirostat = sampling_config["mirostat"].as<int32_t>(); |
| 179 | + if (sampling_config["top_n_sigma"]) params.sampling.top_n_sigma = sampling_config["top_n_sigma"].as<float>(); |
| 180 | + if (sampling_config["mirostat_tau"]) params.sampling.mirostat_tau = sampling_config["mirostat_tau"].as<float>(); |
| 181 | + if (sampling_config["mirostat_eta"]) params.sampling.mirostat_eta = sampling_config["mirostat_eta"].as<float>(); |
| 182 | + if (sampling_config["ignore_eos"]) params.sampling.ignore_eos = sampling_config["ignore_eos"].as<bool>(); |
| 183 | + if (sampling_config["no_perf"]) params.sampling.no_perf = sampling_config["no_perf"].as<bool>(); |
| 184 | + if (sampling_config["timing_per_token"]) params.sampling.timing_per_token = sampling_config["timing_per_token"].as<bool>(); |
| 185 | + if (sampling_config["grammar"]) params.sampling.grammar = sampling_config["grammar"].as<std::string>(); |
| 186 | + if (sampling_config["grammar_lazy"]) params.sampling.grammar_lazy = sampling_config["grammar_lazy"].as<bool>(); |
| 187 | + |
| 188 | + if (sampling_config["dry_sequence_breakers"]) { |
| 189 | + params.sampling.dry_sequence_breakers.clear(); |
| 190 | + for (const auto & breaker : sampling_config["dry_sequence_breakers"]) { |
| 191 | + params.sampling.dry_sequence_breakers.push_back(breaker.as<std::string>()); |
| 192 | + } |
| 193 | + } |
| 194 | + } |
| 195 | + |
| 196 | + // Speculative parameters |
| 197 | + if (config["speculative"]) { |
| 198 | + const auto & spec_config = config["speculative"]; |
| 199 | + if (spec_config["n_ctx"]) params.speculative.n_ctx = spec_config["n_ctx"].as<int32_t>(); |
| 200 | + if (spec_config["n_max"]) params.speculative.n_max = spec_config["n_max"].as<int32_t>(); |
| 201 | + if (spec_config["n_min"]) params.speculative.n_min = spec_config["n_min"].as<int32_t>(); |
| 202 | + if (spec_config["n_gpu_layers"]) params.speculative.n_gpu_layers = spec_config["n_gpu_layers"].as<int32_t>(); |
| 203 | + if (spec_config["p_split"]) params.speculative.p_split = spec_config["p_split"].as<float>(); |
| 204 | + if (spec_config["p_min"]) params.speculative.p_min = spec_config["p_min"].as<float>(); |
| 205 | + |
| 206 | + if (spec_config["model"]) { |
| 207 | + const auto & model_config = spec_config["model"]; |
| 208 | + if (model_config["path"]) params.speculative.model.path = model_config["path"].as<std::string>(); |
| 209 | + if (model_config["url"]) params.speculative.model.url = model_config["url"].as<std::string>(); |
| 210 | + if (model_config["hf_repo"]) params.speculative.model.hf_repo = model_config["hf_repo"].as<std::string>(); |
| 211 | + if (model_config["hf_file"]) params.speculative.model.hf_file = model_config["hf_file"].as<std::string>(); |
| 212 | + } |
| 213 | + } |
| 214 | + |
| 215 | + if (config["antiprompt"]) { |
| 216 | + params.antiprompt.clear(); |
| 217 | + for (const auto & antiprompt : config["antiprompt"]) { |
| 218 | + params.antiprompt.push_back(antiprompt.as<std::string>()); |
| 219 | + } |
| 220 | + } |
| 221 | + |
| 222 | + return true; |
| 223 | + } catch (const YAML::Exception & e) { |
| 224 | + LOG_ERR("Error parsing YAML config file '%s': %s\n", config_file.c_str(), e.what()); |
| 225 | + return false; |
| 226 | + } catch (const std::exception & e) { |
| 227 | + LOG_ERR("Error loading YAML config file '%s': %s\n", config_file.c_str(), e.what()); |
| 228 | + return false; |
| 229 | + } |
| 230 | +} |
| 231 | + |
68 | 232 | common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) { |
69 | 233 | this->examples = std::move(examples); |
70 | 234 | return *this; |
@@ -1227,6 +1391,20 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e |
1227 | 1391 | ctx_arg.params = params_org; |
1228 | 1392 | return false; |
1229 | 1393 | } |
| 1394 | + |
| 1395 | + // Load YAML config if specified |
| 1396 | + if (!ctx_arg.params.config_file.empty()) { |
| 1397 | + if (!common_params_load_from_yaml(ctx_arg.params.config_file, ctx_arg.params)) { |
| 1398 | + ctx_arg.params = params_org; |
| 1399 | + return false; |
| 1400 | + } |
| 1401 | + |
| 1402 | + if (!common_params_parse_ex(argc, argv, ctx_arg)) { |
| 1403 | + ctx_arg.params = params_org; |
| 1404 | + return false; |
| 1405 | + } |
| 1406 | + } |
| 1407 | + |
1230 | 1408 | if (ctx_arg.params.usage) { |
1231 | 1409 | common_params_print_usage(ctx_arg); |
1232 | 1410 | if (ctx_arg.print_usage) { |
@@ -1317,6 +1495,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex |
1317 | 1495 | params.completion = true; |
1318 | 1496 | } |
1319 | 1497 | )); |
| 1498 | + add_opt(common_arg( |
| 1499 | + {"--config"}, "FNAME", |
| 1500 | + "path to a YAML config file (default: none)", |
| 1501 | + [](common_params & params, const std::string & value) { |
| 1502 | + params.config_file = value; |
| 1503 | + } |
| 1504 | + )); |
1320 | 1505 | add_opt(common_arg( |
1321 | 1506 | {"--verbose-prompt"}, |
1322 | 1507 | string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"), |
|
0 commit comments