#include "config.h"
#include "log.h"

#include <yaml-cpp/yaml.h>
#include <cstdint>
#include <filesystem>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace fs = std::filesystem;

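// complete set of YAML keys recognized by common_load_yaml_config, expressed as
// dotted paths for nested maps (e.g. "model.path", "sampling.temp")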
static std::set<std::string> get_valid_keys() {
    return {
        "model.path", "model.url", "model.hf_repo", "model.hf_file",
        "model_alias", "hf_token", "prompt", "system_prompt", "prompt_file",
        "n_predict", "n_ctx", "n_batch", "n_ubatch", "n_keep", "n_chunks",
        "n_parallel", "n_sequences", "grp_attn_n", "grp_attn_w", "n_print",
        "rope_freq_base", "rope_freq_scale", "yarn_ext_factor", "yarn_attn_factor",
        "yarn_beta_fast", "yarn_beta_slow", "yarn_orig_ctx",
        "n_gpu_layers", "main_gpu", "split_mode", "pooling_type", "attention_type",
        "flash_attn_type", "numa", "use_mmap", "use_mlock", "verbose_prompt",
        "display_prompt", "no_kv_offload", "warmup", "check_tensors", "no_op_offload",
        "no_extra_bufts", "cache_type_k", "cache_type_v", "conversation_mode",
        "simple_io", "interactive", "interactive_first", "input_prefix", "input_suffix",
        "logits_file", "path_prompt_cache", "antiprompt", "in_files", "kv_overrides",
        "tensor_buft_overrides", "lora_adapters", "control_vectors", "image", "seed",
        "sampling.seed", "sampling.n_prev", "sampling.n_probs", "sampling.min_keep",
        "sampling.top_k", "sampling.top_p", "sampling.min_p", "sampling.xtc_probability",
        "sampling.xtc_threshold", "sampling.typ_p", "sampling.temp", "sampling.dynatemp_range",
        "sampling.dynatemp_exponent", "sampling.penalty_last_n", "sampling.penalty_repeat",
        "sampling.penalty_freq", "sampling.penalty_present", "sampling.dry_multiplier",
        "sampling.dry_base", "sampling.dry_allowed_length", "sampling.dry_penalty_last_n",
        "sampling.mirostat", "sampling.mirostat_tau", "sampling.mirostat_eta",
        "sampling.top_n_sigma", "sampling.ignore_eos", "sampling.no_perf",
        "sampling.timing_per_token", "sampling.dry_sequence_breakers", "sampling.samplers",
        "sampling.grammar", "sampling.grammar_lazy", "sampling.grammar_triggers",
        "speculative.devices", "speculative.n_ctx", "speculative.n_max", "speculative.n_min",
        "speculative.n_gpu_layers", "speculative.p_split", "speculative.p_min",
        "speculative.model.path", "speculative.model.url", "speculative.model.hf_repo",
        "speculative.model.hf_file", "speculative.tensor_buft_overrides",
        "speculative.cpuparams", "speculative.cpuparams_batch",
        "vocoder.model.path", "vocoder.model.url", "vocoder.model.hf_repo",
        "vocoder.model.hf_file", "vocoder.speaker_file", "vocoder.use_guide_tokens"
    };
}

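// comma-separated list of all valid keys, used when reporting unknown keys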
std::string common_yaml_valid_keys_help() {
    const auto keys = get_valid_keys();
    std::ostringstream ss;
    bool first = true;
    for (const auto & key : keys) {
        if (!first) ss << ", ";
        ss << key;
        first = false;
    }
    return ss.str();
}

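// resolve a relative path against the directory containing the YAML file,
// so configs can reference models and prompts relative to their own location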
static std::string resolve_path(const std::string & path, const fs::path & yaml_dir) {
    fs::path p(path);
    if (p.is_absolute()) {
        return path;
    }
    return fs::weakly_canonical(yaml_dir / p).string();
}

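// recursively flatten nested map keys into dotted paths; sequences are not descended into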
static void collect_keys(const YAML::Node & node, const std::string & prefix, std::set<std::string> & found_keys) {
    if (node.IsMap()) {
        for (const auto & kv : node) {
            std::string key = kv.first.as<std::string>();
            std::string full_key = prefix.empty() ? key : prefix + "." + key;
            found_keys.insert(full_key);
            collect_keys(kv.second, full_key, found_keys);
        }
    }
}

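// reject any key that is neither in the valid set nor a prefix of a valid nested key
// (e.g. "model" is accepted because "model.path" is valid); throws std::invalid_argument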
static void validate_keys(const YAML::Node & root) {
    std::set<std::string> found_keys;
    collect_keys(root, "", found_keys);

    const auto valid_keys = get_valid_keys();
    std::vector<std::string> unknown_keys;

    for (const auto & key : found_keys) {
        if (valid_keys.find(key) == valid_keys.end()) {
            bool is_parent = false;
            for (const auto & valid_key : valid_keys) {
                if (valid_key.find(key + ".") == 0) {
                    is_parent = true;
                    break;
                }
            }
            if (!is_parent) {
                unknown_keys.push_back(key);
            }
        }
    }

    if (!unknown_keys.empty()) {
        std::ostringstream ss;
        ss << "Unknown YAML keys: ";
        for (size_t i = 0; i < unknown_keys.size(); ++i) {
            if (i > 0) ss << ", ";
            ss << unknown_keys[i];
        }
        ss << "; valid keys are: " << common_yaml_valid_keys_help();
        throw std::invalid_argument(ss.str());
    }
}

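// string -> enum helpers; each throws std::invalid_argument for unrecognized values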
static ggml_type parse_ggml_type(const std::string & type_str) {
    if (type_str == "f32") return GGML_TYPE_F32;
    if (type_str == "f16") return GGML_TYPE_F16;
    if (type_str == "bf16") return GGML_TYPE_BF16;
    if (type_str == "q8_0") return GGML_TYPE_Q8_0;
    if (type_str == "q4_0") return GGML_TYPE_Q4_0;
    if (type_str == "q4_1") return GGML_TYPE_Q4_1;
    if (type_str == "iq4_nl") return GGML_TYPE_IQ4_NL;
    if (type_str == "q5_0") return GGML_TYPE_Q5_0;
    if (type_str == "q5_1") return GGML_TYPE_Q5_1;
    throw std::invalid_argument("Unknown ggml_type: " + type_str);
}

static enum llama_split_mode parse_split_mode(const std::string & mode_str) {
    if (mode_str == "none") return LLAMA_SPLIT_MODE_NONE;
    if (mode_str == "layer") return LLAMA_SPLIT_MODE_LAYER;
    if (mode_str == "row") return LLAMA_SPLIT_MODE_ROW;
    throw std::invalid_argument("Unknown split_mode: " + mode_str);
}

static enum llama_pooling_type parse_pooling_type(const std::string & type_str) {
    if (type_str == "unspecified") return LLAMA_POOLING_TYPE_UNSPECIFIED;
    if (type_str == "none") return LLAMA_POOLING_TYPE_NONE;
    if (type_str == "mean") return LLAMA_POOLING_TYPE_MEAN;
    if (type_str == "cls") return LLAMA_POOLING_TYPE_CLS;
    if (type_str == "last") return LLAMA_POOLING_TYPE_LAST;
    if (type_str == "rank") return LLAMA_POOLING_TYPE_RANK;
    throw std::invalid_argument("Unknown pooling_type: " + type_str);
}

static enum llama_attention_type parse_attention_type(const std::string & type_str) {
    if (type_str == "unspecified") return LLAMA_ATTENTION_TYPE_UNSPECIFIED;
    if (type_str == "causal") return LLAMA_ATTENTION_TYPE_CAUSAL;
    if (type_str == "non_causal") return LLAMA_ATTENTION_TYPE_NON_CAUSAL;
    throw std::invalid_argument("Unknown attention_type: " + type_str);
}

static enum llama_flash_attn_type parse_flash_attn_type(const std::string & type_str) {
    if (type_str == "auto") return LLAMA_FLASH_ATTN_TYPE_AUTO;
    if (type_str == "disabled") return LLAMA_FLASH_ATTN_TYPE_DISABLED;
    if (type_str == "enabled") return LLAMA_FLASH_ATTN_TYPE_ENABLED;
    throw std::invalid_argument("Unknown flash_attn_type: " + type_str);
}

static ggml_numa_strategy parse_numa_strategy(const std::string & strategy_str) {
    if (strategy_str == "disabled") return GGML_NUMA_STRATEGY_DISABLED;
    if (strategy_str == "distribute") return GGML_NUMA_STRATEGY_DISTRIBUTE;
    if (strategy_str == "isolate") return GGML_NUMA_STRATEGY_ISOLATE;
    if (strategy_str == "numactl") return GGML_NUMA_STRATEGY_NUMACTL;
    if (strategy_str == "mirror") return GGML_NUMA_STRATEGY_MIRROR;
    throw std::invalid_argument("Unknown numa_strategy: " + strategy_str);
}

static common_conversation_mode parse_conversation_mode(const std::string & mode_str) {
    if (mode_str == "auto") return COMMON_CONVERSATION_MODE_AUTO;
    if (mode_str == "enabled") return COMMON_CONVERSATION_MODE_ENABLED;
    if (mode_str == "disabled") return COMMON_CONVERSATION_MODE_DISABLED;
    throw std::invalid_argument("Unknown conversation_mode: " + mode_str);
}

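// load a YAML config file into common_params: parse the file, validate its key set,
// then overwrite the existing defaults with any recognized values; relative paths are
// resolved against the YAML file's directory, and failures surface as std::invalid_argument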
bool common_load_yaml_config(const std::string & path, common_params & params) {
    try {
        YAML::Node root = YAML::LoadFile(path);

        validate_keys(root);

        fs::path yaml_dir = fs::absolute(path).parent_path();

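        // model source: local path (resolved relative to the YAML file), direct URL,
        // or Hugging Face repo/file pair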
        if (root["model"]) {
            auto model = root["model"];
            if (model["path"]) {
                params.model.path = resolve_path(model["path"].as<std::string>(), yaml_dir);
            }
            if (model["url"]) {
                params.model.url = model["url"].as<std::string>();
            }
            if (model["hf_repo"]) {
                params.model.hf_repo = model["hf_repo"].as<std::string>();
            }
            if (model["hf_file"]) {
                params.model.hf_file = model["hf_file"].as<std::string>();
            }
        }

        if (root["model_alias"]) params.model_alias = root["model_alias"].as<std::string>();
        if (root["hf_token"]) params.hf_token = root["hf_token"].as<std::string>();
        if (root["prompt"]) params.prompt = root["prompt"].as<std::string>();
        if (root["system_prompt"]) params.system_prompt = root["system_prompt"].as<std::string>();
        if (root["prompt_file"]) {
            params.prompt_file = resolve_path(root["prompt_file"].as<std::string>(), yaml_dir);
        }

        if (root["n_predict"]) params.n_predict = root["n_predict"].as<int32_t>();
        if (root["n_ctx"]) params.n_ctx = root["n_ctx"].as<int32_t>();
        if (root["n_batch"]) params.n_batch = root["n_batch"].as<int32_t>();
        if (root["n_ubatch"]) params.n_ubatch = root["n_ubatch"].as<int32_t>();
        if (root["n_keep"]) params.n_keep = root["n_keep"].as<int32_t>();
        if (root["n_chunks"]) params.n_chunks = root["n_chunks"].as<int32_t>();
        if (root["n_parallel"]) params.n_parallel = root["n_parallel"].as<int32_t>();
        if (root["n_sequences"]) params.n_sequences = root["n_sequences"].as<int32_t>();
        if (root["grp_attn_n"]) params.grp_attn_n = root["grp_attn_n"].as<int32_t>();
        if (root["grp_attn_w"]) params.grp_attn_w = root["grp_attn_w"].as<int32_t>();
        if (root["n_print"]) params.n_print = root["n_print"].as<int32_t>();

        if (root["rope_freq_base"]) params.rope_freq_base = root["rope_freq_base"].as<float>();
        if (root["rope_freq_scale"]) params.rope_freq_scale = root["rope_freq_scale"].as<float>();
        if (root["yarn_ext_factor"]) params.yarn_ext_factor = root["yarn_ext_factor"].as<float>();
        if (root["yarn_attn_factor"]) params.yarn_attn_factor = root["yarn_attn_factor"].as<float>();
        if (root["yarn_beta_fast"]) params.yarn_beta_fast = root["yarn_beta_fast"].as<float>();
        if (root["yarn_beta_slow"]) params.yarn_beta_slow = root["yarn_beta_slow"].as<float>();
        if (root["yarn_orig_ctx"]) params.yarn_orig_ctx = root["yarn_orig_ctx"].as<int32_t>();

        if (root["n_gpu_layers"]) params.n_gpu_layers = root["n_gpu_layers"].as<int32_t>();
        if (root["main_gpu"]) params.main_gpu = root["main_gpu"].as<int32_t>();

        if (root["split_mode"]) {
            params.split_mode = parse_split_mode(root["split_mode"].as<std::string>());
        }
        if (root["pooling_type"]) {
            params.pooling_type = parse_pooling_type(root["pooling_type"].as<std::string>());
        }
        if (root["attention_type"]) {
            params.attention_type = parse_attention_type(root["attention_type"].as<std::string>());
        }
        if (root["flash_attn_type"]) {
            params.flash_attn_type = parse_flash_attn_type(root["flash_attn_type"].as<std::string>());
        }
        if (root["numa"]) {
            params.numa = parse_numa_strategy(root["numa"].as<std::string>());
        }
        if (root["conversation_mode"]) {
            params.conversation_mode = parse_conversation_mode(root["conversation_mode"].as<std::string>());
        }

        if (root["use_mmap"]) params.use_mmap = root["use_mmap"].as<bool>();
        if (root["use_mlock"]) params.use_mlock = root["use_mlock"].as<bool>();
        if (root["verbose_prompt"]) params.verbose_prompt = root["verbose_prompt"].as<bool>();
        if (root["display_prompt"]) params.display_prompt = root["display_prompt"].as<bool>();
        if (root["no_kv_offload"]) params.no_kv_offload = root["no_kv_offload"].as<bool>();
        if (root["warmup"]) params.warmup = root["warmup"].as<bool>();
        if (root["check_tensors"]) params.check_tensors = root["check_tensors"].as<bool>();
        if (root["no_op_offload"]) params.no_op_offload = root["no_op_offload"].as<bool>();
        if (root["no_extra_bufts"]) params.no_extra_bufts = root["no_extra_bufts"].as<bool>();
        if (root["simple_io"]) params.simple_io = root["simple_io"].as<bool>();
        if (root["interactive"]) params.interactive = root["interactive"].as<bool>();
        if (root["interactive_first"]) params.interactive_first = root["interactive_first"].as<bool>();

        if (root["input_prefix"]) params.input_prefix = root["input_prefix"].as<std::string>();
        if (root["input_suffix"]) params.input_suffix = root["input_suffix"].as<std::string>();
        if (root["logits_file"]) {
            params.logits_file = resolve_path(root["logits_file"].as<std::string>(), yaml_dir);
        }
        if (root["path_prompt_cache"]) {
            params.path_prompt_cache = resolve_path(root["path_prompt_cache"].as<std::string>(), yaml_dir);
        }

        if (root["cache_type_k"]) {
            params.cache_type_k = parse_ggml_type(root["cache_type_k"].as<std::string>());
        }
        if (root["cache_type_v"]) {
            params.cache_type_v = parse_ggml_type(root["cache_type_v"].as<std::string>());
        }

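        // list-valued keys replace any existing entries; file lists are resolved
        // relative to the YAML file's directory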
        if (root["antiprompt"]) {
            params.antiprompt.clear();
            for (const auto & item : root["antiprompt"]) {
                params.antiprompt.push_back(item.as<std::string>());
            }
        }

        if (root["in_files"]) {
            params.in_files.clear();
            for (const auto & item : root["in_files"]) {
                params.in_files.push_back(resolve_path(item.as<std::string>(), yaml_dir));
            }
        }

        if (root["image"]) {
            params.image.clear();
            for (const auto & item : root["image"]) {
                params.image.push_back(resolve_path(item.as<std::string>(), yaml_dir));
            }
        }

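        // top-level "seed" is shorthand for sampling.seed; an explicit sampling.seed
        // below takes precedence since the sampling block is applied afterwards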
        if (root["seed"]) {
            params.sampling.seed = root["seed"].as<uint32_t>();
        }

        if (root["sampling"]) {
            auto sampling = root["sampling"];
            if (sampling["seed"]) params.sampling.seed = sampling["seed"].as<uint32_t>();
            if (sampling["n_prev"]) params.sampling.n_prev = sampling["n_prev"].as<int32_t>();
            if (sampling["n_probs"]) params.sampling.n_probs = sampling["n_probs"].as<int32_t>();
            if (sampling["min_keep"]) params.sampling.min_keep = sampling["min_keep"].as<int32_t>();
            if (sampling["top_k"]) params.sampling.top_k = sampling["top_k"].as<int32_t>();
            if (sampling["top_p"]) params.sampling.top_p = sampling["top_p"].as<float>();
            if (sampling["min_p"]) params.sampling.min_p = sampling["min_p"].as<float>();
            if (sampling["xtc_probability"]) params.sampling.xtc_probability = sampling["xtc_probability"].as<float>();
            if (sampling["xtc_threshold"]) params.sampling.xtc_threshold = sampling["xtc_threshold"].as<float>();
            if (sampling["typ_p"]) params.sampling.typ_p = sampling["typ_p"].as<float>();
            if (sampling["temp"]) params.sampling.temp = sampling["temp"].as<float>();
            if (sampling["dynatemp_range"]) params.sampling.dynatemp_range = sampling["dynatemp_range"].as<float>();
            if (sampling["dynatemp_exponent"]) params.sampling.dynatemp_exponent = sampling["dynatemp_exponent"].as<float>();
            if (sampling["penalty_last_n"]) params.sampling.penalty_last_n = sampling["penalty_last_n"].as<int32_t>();
            if (sampling["penalty_repeat"]) params.sampling.penalty_repeat = sampling["penalty_repeat"].as<float>();
            if (sampling["penalty_freq"]) params.sampling.penalty_freq = sampling["penalty_freq"].as<float>();
            if (sampling["penalty_present"]) params.sampling.penalty_present = sampling["penalty_present"].as<float>();
            if (sampling["dry_multiplier"]) params.sampling.dry_multiplier = sampling["dry_multiplier"].as<float>();
            if (sampling["dry_base"]) params.sampling.dry_base = sampling["dry_base"].as<float>();
            if (sampling["dry_allowed_length"]) params.sampling.dry_allowed_length = sampling["dry_allowed_length"].as<int32_t>();
            if (sampling["dry_penalty_last_n"]) params.sampling.dry_penalty_last_n = sampling["dry_penalty_last_n"].as<int32_t>();
            if (sampling["mirostat"]) params.sampling.mirostat = sampling["mirostat"].as<int32_t>();
            if (sampling["mirostat_tau"]) params.sampling.mirostat_tau = sampling["mirostat_tau"].as<float>();
            if (sampling["mirostat_eta"]) params.sampling.mirostat_eta = sampling["mirostat_eta"].as<float>();
            if (sampling["top_n_sigma"]) params.sampling.top_n_sigma = sampling["top_n_sigma"].as<float>();
            if (sampling["ignore_eos"]) params.sampling.ignore_eos = sampling["ignore_eos"].as<bool>();
            if (sampling["no_perf"]) params.sampling.no_perf = sampling["no_perf"].as<bool>();
            if (sampling["timing_per_token"]) params.sampling.timing_per_token = sampling["timing_per_token"].as<bool>();
            if (sampling["grammar"]) params.sampling.grammar = sampling["grammar"].as<std::string>();
            if (sampling["grammar_lazy"]) params.sampling.grammar_lazy = sampling["grammar_lazy"].as<bool>();
        }

        return true;
    } catch (const YAML::Exception & e) {
        throw std::invalid_argument("YAML parsing error: " + std::string(e.what()));
    } catch (const std::exception & e) {
        throw std::invalid_argument("Config loading error: " + std::string(e.what()));
    }
}
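
// illustrative usage sketch (file name and values below are hypothetical):
//
//     common_params params;
//     try {
//         common_load_yaml_config("config.yaml", params);
//     } catch (const std::invalid_argument & e) {
//         fprintf(stderr, "config error: %s\n", e.what());
//     }
//
// with a config.yaml along the lines of:
//
//     model:
//       path: ./models/example.gguf   # resolved relative to this YAML file
//     n_ctx: 4096
//     sampling:
//       temp: 0.7
//       top_p: 0.9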