Skip to content

Commit 5e0f94f

Browse files
feat(common): Add YAML config loader with yaml-cpp dependency
- Add common/config.h and common/config.cpp for YAML configuration loading - Implement common_load_yaml_config() with validation and path resolution - Add yaml-cpp dependency via FetchContent in common/CMakeLists.txt - Support nested config structure (model, sampling, speculative, vocoder) - Reject unknown keys with descriptive error messages - Resolve relative paths relative to YAML file directory Co-Authored-By: Jaime Mizrachi <[email protected]>
1 parent 661ae31 commit 5e0f94f

File tree

3 files changed

+359
-1
lines changed

3 files changed

+359
-1
lines changed

common/CMakeLists.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22

33
find_package(Threads REQUIRED)
44

5+
# Prefer a yaml-cpp package that find_package can already locate; when none is
# found, fetch release 0.8.0 from the upstream repository and build it as part
# of this project via FetchContent.
find_package(yaml-cpp QUIET)
6+
if (NOT yaml-cpp_FOUND)
7+
include(FetchContent)
8+
FetchContent_Declare(yaml-cpp
9+
GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
10+
GIT_TAG 0.8.0)
11+
FetchContent_MakeAvailable(yaml-cpp)
12+
endif()
13+
514
llama_add_compile_flags()
615

716
# Build info header
@@ -54,6 +63,8 @@ add_library(${TARGET} STATIC
5463
chat.h
5564
common.cpp
5665
common.h
66+
config.cpp
67+
config.h
5768
console.cpp
5869
console.h
5970
json-partial.cpp
@@ -135,7 +146,7 @@ endif ()
135146

136147
target_include_directories(${TARGET} PUBLIC . ../vendor)
137148
target_compile_features (${TARGET} PUBLIC cxx_std_17)
138-
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
149+
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} yaml-cpp PUBLIC llama Threads::Threads)
139150

140151

141152
#

common/config.cpp

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
#include "config.h"
2+
#include "log.h"
3+
4+
#include <yaml-cpp/yaml.h>
5+
#include <filesystem>
6+
#include <set>
7+
#include <sstream>
8+
#include <stdexcept>
9+
10+
namespace fs = std::filesystem;
11+
12+
// Returns every key path a YAML config file may contain. Nested sections are
// flattened into dotted paths (e.g. "sampling.top_k", "speculative.model.path").
static std::set<std::string> get_valid_keys() {
    static const char * const k_key_names[] = {
        // model source
        "model.path", "model.url", "model.hf_repo", "model.hf_file",
        // top-level strings
        "model_alias", "hf_token", "prompt", "system_prompt", "prompt_file",
        // sizes and counts
        "n_predict", "n_ctx", "n_batch", "n_ubatch", "n_keep", "n_chunks",
        "n_parallel", "n_sequences", "grp_attn_n", "grp_attn_w", "n_print",
        // RoPE / YaRN
        "rope_freq_base", "rope_freq_scale", "yarn_ext_factor", "yarn_attn_factor",
        "yarn_beta_fast", "yarn_beta_slow", "yarn_orig_ctx",
        // device and attention selection
        "n_gpu_layers", "main_gpu", "split_mode", "pooling_type", "attention_type",
        "flash_attn_type", "numa",
        // boolean switches
        "use_mmap", "use_mlock", "verbose_prompt",
        "display_prompt", "no_kv_offload", "warmup", "check_tensors", "no_op_offload",
        "no_extra_bufts",
        "cache_type_k", "cache_type_v", "conversation_mode",
        "simple_io", "interactive", "interactive_first", "input_prefix", "input_suffix",
        "logits_file", "path_prompt_cache", "antiprompt", "in_files", "kv_overrides",
        "tensor_buft_overrides", "lora_adapters", "control_vectors", "image", "seed",
        // sampling section
        "sampling.seed", "sampling.n_prev", "sampling.n_probs", "sampling.min_keep",
        "sampling.top_k", "sampling.top_p", "sampling.min_p", "sampling.xtc_probability",
        "sampling.xtc_threshold", "sampling.typ_p", "sampling.temp", "sampling.dynatemp_range",
        "sampling.dynatemp_exponent", "sampling.penalty_last_n", "sampling.penalty_repeat",
        "sampling.penalty_freq", "sampling.penalty_present", "sampling.dry_multiplier",
        "sampling.dry_base", "sampling.dry_allowed_length", "sampling.dry_penalty_last_n",
        "sampling.mirostat", "sampling.mirostat_tau", "sampling.mirostat_eta",
        "sampling.top_n_sigma", "sampling.ignore_eos", "sampling.no_perf",
        "sampling.timing_per_token", "sampling.dry_sequence_breakers", "sampling.samplers",
        "sampling.grammar", "sampling.grammar_lazy", "sampling.grammar_triggers",
        // speculative section
        "speculative.devices", "speculative.n_ctx", "speculative.n_max", "speculative.n_min",
        "speculative.n_gpu_layers", "speculative.p_split", "speculative.p_min",
        "speculative.model.path", "speculative.model.url", "speculative.model.hf_repo",
        "speculative.model.hf_file", "speculative.tensor_buft_overrides",
        "speculative.cpuparams", "speculative.cpuparams_batch",
        // vocoder section
        "vocoder.model.path", "vocoder.model.url", "vocoder.model.hf_repo",
        "vocoder.model.hf_file", "vocoder.speaker_file", "vocoder.use_guide_tokens",
    };

    return std::set<std::string>(std::begin(k_key_names), std::end(k_key_names));
}
46+
47+
std::string common_yaml_valid_keys_help() {
48+
const auto keys = get_valid_keys();
49+
std::ostringstream ss;
50+
bool first = true;
51+
for (const auto & key : keys) {
52+
if (!first) ss << ", ";
53+
ss << key;
54+
first = false;
55+
}
56+
return ss.str();
57+
}
58+
59+
// Resolves a path taken from the YAML file.
//
//   path     - path string exactly as written in the config
//   yaml_dir - directory that contains the YAML file
//
// Returns `path` unchanged when it is empty or already absolute; otherwise the
// path is interpreted relative to `yaml_dir` and returned in weakly-canonical
// form. May throw std::filesystem::filesystem_error from weakly_canonical().
static std::string resolve_path(const std::string & path, const std::filesystem::path & yaml_dir) {
    // An empty value means "not set". Joining it onto yaml_dir would yield the
    // directory itself, turning an unset path into a bogus non-empty one.
    if (path.empty()) {
        return path;
    }

    const std::filesystem::path p(path);
    if (p.is_absolute()) {
        return path;
    }
    return std::filesystem::weakly_canonical(yaml_dir / p).string();
}
66+
67+
// Walks `node` recursively and records the dotted path of every map key found
// into `found_keys`. `prefix` is the dotted path of `node` itself ("" for the
// root). Non-map nodes contribute nothing.
static void collect_keys(const YAML::Node & node, const std::string & prefix, std::set<std::string> & found_keys) {
    if (!node.IsMap()) {
        return;
    }

    for (const auto & entry : node) {
        const std::string name   = entry.first.as<std::string>();
        const std::string dotted = prefix.empty() ? name : prefix + "." + name;

        found_keys.insert(dotted);
        collect_keys(entry.second, dotted, found_keys);
    }
}
77+
78+
// Checks every key path present in `root` against the accepted key set.
// A path passes when it is an accepted key or a section prefix of one
// (e.g. "sampling"). Throws std::invalid_argument listing all unknown paths,
// followed by the full list of valid keys.
static void validate_keys(const YAML::Node & root) {
    std::set<std::string> present;
    collect_keys(root, "", present);

    const auto allowed = get_valid_keys();

    std::string rejected;
    bool        any_rejected = false;

    for (const auto & name : present) {
        if (allowed.count(name) != 0) {
            continue;
        }

        // The set is sorted, so if any accepted key starts with "<name>.",
        // the first element not less than that prefix is one of them.
        const std::string stem = name + ".";
        const auto        hit  = allowed.lower_bound(stem);
        if (hit != allowed.end() && hit->compare(0, stem.size(), stem) == 0) {
            continue;
        }

        if (any_rejected) {
            rejected += ", ";
        }
        rejected += name;
        any_rejected = true;
    }

    if (!any_rejected) {
        return;
    }

    throw std::invalid_argument("Unknown YAML keys: " + rejected +
                                "; valid keys are: " + common_yaml_valid_keys_help());
}
111+
112+
// Converts a type name from the config (used for cache_type_k / cache_type_v)
// into the matching ggml_type. Throws std::invalid_argument for other names.
static ggml_type parse_ggml_type(const std::string & type_str) {
    static const struct {
        const char * name;
        ggml_type    value;
    } k_types[] = {
        { "f32",    GGML_TYPE_F32    },
        { "f16",    GGML_TYPE_F16    },
        { "bf16",   GGML_TYPE_BF16   },
        { "q8_0",   GGML_TYPE_Q8_0   },
        { "q4_0",   GGML_TYPE_Q4_0   },
        { "q4_1",   GGML_TYPE_Q4_1   },
        { "iq4_nl", GGML_TYPE_IQ4_NL },
        { "q5_0",   GGML_TYPE_Q5_0   },
        { "q5_1",   GGML_TYPE_Q5_1   },
    };

    for (const auto & row : k_types) {
        if (type_str == row.name) {
            return row.value;
        }
    }
    throw std::invalid_argument("Unknown ggml_type: " + type_str);
}
124+
125+
static enum llama_split_mode parse_split_mode(const std::string & mode_str) {
126+
if (mode_str == "none") return LLAMA_SPLIT_MODE_NONE;
127+
if (mode_str == "layer") return LLAMA_SPLIT_MODE_LAYER;
128+
if (mode_str == "row") return LLAMA_SPLIT_MODE_ROW;
129+
throw std::invalid_argument("Unknown split_mode: " + mode_str);
130+
}
131+
132+
static enum llama_pooling_type parse_pooling_type(const std::string & type_str) {
133+
if (type_str == "unspecified") return LLAMA_POOLING_TYPE_UNSPECIFIED;
134+
if (type_str == "none") return LLAMA_POOLING_TYPE_NONE;
135+
if (type_str == "mean") return LLAMA_POOLING_TYPE_MEAN;
136+
if (type_str == "cls") return LLAMA_POOLING_TYPE_CLS;
137+
if (type_str == "last") return LLAMA_POOLING_TYPE_LAST;
138+
if (type_str == "rank") return LLAMA_POOLING_TYPE_RANK;
139+
throw std::invalid_argument("Unknown pooling_type: " + type_str);
140+
}
141+
142+
static enum llama_attention_type parse_attention_type(const std::string & type_str) {
143+
if (type_str == "unspecified") return LLAMA_ATTENTION_TYPE_UNSPECIFIED;
144+
if (type_str == "causal") return LLAMA_ATTENTION_TYPE_CAUSAL;
145+
if (type_str == "non_causal") return LLAMA_ATTENTION_TYPE_NON_CAUSAL;
146+
throw std::invalid_argument("Unknown attention_type: " + type_str);
147+
}
148+
149+
static enum llama_flash_attn_type parse_flash_attn_type(const std::string & type_str) {
150+
if (type_str == "auto") return LLAMA_FLASH_ATTN_TYPE_AUTO;
151+
if (type_str == "disabled") return LLAMA_FLASH_ATTN_TYPE_DISABLED;
152+
if (type_str == "enabled") return LLAMA_FLASH_ATTN_TYPE_ENABLED;
153+
throw std::invalid_argument("Unknown flash_attn_type: " + type_str);
154+
}
155+
156+
// Converts a NUMA strategy name from the config into ggml_numa_strategy.
// Throws std::invalid_argument for names not listed in the table.
static ggml_numa_strategy parse_numa_strategy(const std::string & strategy_str) {
    static const struct {
        const char *       name;
        ggml_numa_strategy value;
    } k_strategies[] = {
        { "disabled",   GGML_NUMA_STRATEGY_DISABLED   },
        { "distribute", GGML_NUMA_STRATEGY_DISTRIBUTE },
        { "isolate",    GGML_NUMA_STRATEGY_ISOLATE    },
        { "numactl",    GGML_NUMA_STRATEGY_NUMACTL    },
        { "mirror",     GGML_NUMA_STRATEGY_MIRROR     },
    };

    for (const auto & row : k_strategies) {
        if (strategy_str == row.name) {
            return row.value;
        }
    }
    throw std::invalid_argument("Unknown numa_strategy: " + strategy_str);
}
164+
165+
// Converts "auto" / "enabled" / "disabled" into common_conversation_mode.
// Throws std::invalid_argument for any other string.
static common_conversation_mode parse_conversation_mode(const std::string & mode_str) {
    static const struct {
        const char *             name;
        common_conversation_mode value;
    } k_modes[] = {
        { "auto",     COMMON_CONVERSATION_MODE_AUTO     },
        { "enabled",  COMMON_CONVERSATION_MODE_ENABLED  },
        { "disabled", COMMON_CONVERSATION_MODE_DISABLED },
    };

    for (const auto & row : k_modes) {
        if (mode_str == row.name) {
            return row.value;
        }
    }
    throw std::invalid_argument("Unknown conversation_mode: " + mode_str);
}
171+
172+
bool common_load_yaml_config(const std::string & path, common_params & params) {
173+
try {
174+
YAML::Node root = YAML::LoadFile(path);
175+
176+
validate_keys(root);
177+
178+
fs::path yaml_dir = fs::absolute(path).parent_path();
179+
180+
if (root["model"]) {
181+
auto model = root["model"];
182+
if (model["path"]) {
183+
params.model.path = resolve_path(model["path"].as<std::string>(), yaml_dir);
184+
}
185+
if (model["url"]) {
186+
params.model.url = model["url"].as<std::string>();
187+
}
188+
if (model["hf_repo"]) {
189+
params.model.hf_repo = model["hf_repo"].as<std::string>();
190+
}
191+
if (model["hf_file"]) {
192+
params.model.hf_file = model["hf_file"].as<std::string>();
193+
}
194+
}
195+
196+
if (root["model_alias"]) params.model_alias = root["model_alias"].as<std::string>();
197+
if (root["hf_token"]) params.hf_token = root["hf_token"].as<std::string>();
198+
if (root["prompt"]) params.prompt = root["prompt"].as<std::string>();
199+
if (root["system_prompt"]) params.system_prompt = root["system_prompt"].as<std::string>();
200+
if (root["prompt_file"]) {
201+
params.prompt_file = resolve_path(root["prompt_file"].as<std::string>(), yaml_dir);
202+
}
203+
204+
if (root["n_predict"]) params.n_predict = root["n_predict"].as<int32_t>();
205+
if (root["n_ctx"]) params.n_ctx = root["n_ctx"].as<int32_t>();
206+
if (root["n_batch"]) params.n_batch = root["n_batch"].as<int32_t>();
207+
if (root["n_ubatch"]) params.n_ubatch = root["n_ubatch"].as<int32_t>();
208+
if (root["n_keep"]) params.n_keep = root["n_keep"].as<int32_t>();
209+
if (root["n_chunks"]) params.n_chunks = root["n_chunks"].as<int32_t>();
210+
if (root["n_parallel"]) params.n_parallel = root["n_parallel"].as<int32_t>();
211+
if (root["n_sequences"]) params.n_sequences = root["n_sequences"].as<int32_t>();
212+
if (root["grp_attn_n"]) params.grp_attn_n = root["grp_attn_n"].as<int32_t>();
213+
if (root["grp_attn_w"]) params.grp_attn_w = root["grp_attn_w"].as<int32_t>();
214+
if (root["n_print"]) params.n_print = root["n_print"].as<int32_t>();
215+
216+
if (root["rope_freq_base"]) params.rope_freq_base = root["rope_freq_base"].as<float>();
217+
if (root["rope_freq_scale"]) params.rope_freq_scale = root["rope_freq_scale"].as<float>();
218+
if (root["yarn_ext_factor"]) params.yarn_ext_factor = root["yarn_ext_factor"].as<float>();
219+
if (root["yarn_attn_factor"]) params.yarn_attn_factor = root["yarn_attn_factor"].as<float>();
220+
if (root["yarn_beta_fast"]) params.yarn_beta_fast = root["yarn_beta_fast"].as<float>();
221+
if (root["yarn_beta_slow"]) params.yarn_beta_slow = root["yarn_beta_slow"].as<float>();
222+
if (root["yarn_orig_ctx"]) params.yarn_orig_ctx = root["yarn_orig_ctx"].as<int32_t>();
223+
224+
if (root["n_gpu_layers"]) params.n_gpu_layers = root["n_gpu_layers"].as<int32_t>();
225+
if (root["main_gpu"]) params.main_gpu = root["main_gpu"].as<int32_t>();
226+
227+
if (root["split_mode"]) {
228+
params.split_mode = parse_split_mode(root["split_mode"].as<std::string>());
229+
}
230+
if (root["pooling_type"]) {
231+
params.pooling_type = parse_pooling_type(root["pooling_type"].as<std::string>());
232+
}
233+
if (root["attention_type"]) {
234+
params.attention_type = parse_attention_type(root["attention_type"].as<std::string>());
235+
}
236+
if (root["flash_attn_type"]) {
237+
params.flash_attn_type = parse_flash_attn_type(root["flash_attn_type"].as<std::string>());
238+
}
239+
if (root["numa"]) {
240+
params.numa = parse_numa_strategy(root["numa"].as<std::string>());
241+
}
242+
if (root["conversation_mode"]) {
243+
params.conversation_mode = parse_conversation_mode(root["conversation_mode"].as<std::string>());
244+
}
245+
246+
if (root["use_mmap"]) params.use_mmap = root["use_mmap"].as<bool>();
247+
if (root["use_mlock"]) params.use_mlock = root["use_mlock"].as<bool>();
248+
if (root["verbose_prompt"]) params.verbose_prompt = root["verbose_prompt"].as<bool>();
249+
if (root["display_prompt"]) params.display_prompt = root["display_prompt"].as<bool>();
250+
if (root["no_kv_offload"]) params.no_kv_offload = root["no_kv_offload"].as<bool>();
251+
if (root["warmup"]) params.warmup = root["warmup"].as<bool>();
252+
if (root["check_tensors"]) params.check_tensors = root["check_tensors"].as<bool>();
253+
if (root["no_op_offload"]) params.no_op_offload = root["no_op_offload"].as<bool>();
254+
if (root["no_extra_bufts"]) params.no_extra_bufts = root["no_extra_bufts"].as<bool>();
255+
if (root["simple_io"]) params.simple_io = root["simple_io"].as<bool>();
256+
if (root["interactive"]) params.interactive = root["interactive"].as<bool>();
257+
if (root["interactive_first"]) params.interactive_first = root["interactive_first"].as<bool>();
258+
259+
if (root["input_prefix"]) params.input_prefix = root["input_prefix"].as<std::string>();
260+
if (root["input_suffix"]) params.input_suffix = root["input_suffix"].as<std::string>();
261+
if (root["logits_file"]) {
262+
params.logits_file = resolve_path(root["logits_file"].as<std::string>(), yaml_dir);
263+
}
264+
if (root["path_prompt_cache"]) {
265+
params.path_prompt_cache = resolve_path(root["path_prompt_cache"].as<std::string>(), yaml_dir);
266+
}
267+
268+
if (root["cache_type_k"]) {
269+
params.cache_type_k = parse_ggml_type(root["cache_type_k"].as<std::string>());
270+
}
271+
if (root["cache_type_v"]) {
272+
params.cache_type_v = parse_ggml_type(root["cache_type_v"].as<std::string>());
273+
}
274+
275+
if (root["antiprompt"]) {
276+
params.antiprompt.clear();
277+
for (const auto & item : root["antiprompt"]) {
278+
params.antiprompt.push_back(item.as<std::string>());
279+
}
280+
}
281+
282+
if (root["in_files"]) {
283+
params.in_files.clear();
284+
for (const auto & item : root["in_files"]) {
285+
params.in_files.push_back(resolve_path(item.as<std::string>(), yaml_dir));
286+
}
287+
}
288+
289+
if (root["image"]) {
290+
params.image.clear();
291+
for (const auto & item : root["image"]) {
292+
params.image.push_back(resolve_path(item.as<std::string>(), yaml_dir));
293+
}
294+
}
295+
296+
if (root["seed"]) {
297+
params.sampling.seed = root["seed"].as<uint32_t>();
298+
}
299+
300+
if (root["sampling"]) {
301+
auto sampling = root["sampling"];
302+
if (sampling["seed"]) params.sampling.seed = sampling["seed"].as<uint32_t>();
303+
if (sampling["n_prev"]) params.sampling.n_prev = sampling["n_prev"].as<int32_t>();
304+
if (sampling["n_probs"]) params.sampling.n_probs = sampling["n_probs"].as<int32_t>();
305+
if (sampling["min_keep"]) params.sampling.min_keep = sampling["min_keep"].as<int32_t>();
306+
if (sampling["top_k"]) params.sampling.top_k = sampling["top_k"].as<int32_t>();
307+
if (sampling["top_p"]) params.sampling.top_p = sampling["top_p"].as<float>();
308+
if (sampling["min_p"]) params.sampling.min_p = sampling["min_p"].as<float>();
309+
if (sampling["xtc_probability"]) params.sampling.xtc_probability = sampling["xtc_probability"].as<float>();
310+
if (sampling["xtc_threshold"]) params.sampling.xtc_threshold = sampling["xtc_threshold"].as<float>();
311+
if (sampling["typ_p"]) params.sampling.typ_p = sampling["typ_p"].as<float>();
312+
if (sampling["temp"]) params.sampling.temp = sampling["temp"].as<float>();
313+
if (sampling["dynatemp_range"]) params.sampling.dynatemp_range = sampling["dynatemp_range"].as<float>();
314+
if (sampling["dynatemp_exponent"]) params.sampling.dynatemp_exponent = sampling["dynatemp_exponent"].as<float>();
315+
if (sampling["penalty_last_n"]) params.sampling.penalty_last_n = sampling["penalty_last_n"].as<int32_t>();
316+
if (sampling["penalty_repeat"]) params.sampling.penalty_repeat = sampling["penalty_repeat"].as<float>();
317+
if (sampling["penalty_freq"]) params.sampling.penalty_freq = sampling["penalty_freq"].as<float>();
318+
if (sampling["penalty_present"]) params.sampling.penalty_present = sampling["penalty_present"].as<float>();
319+
if (sampling["dry_multiplier"]) params.sampling.dry_multiplier = sampling["dry_multiplier"].as<float>();
320+
if (sampling["dry_base"]) params.sampling.dry_base = sampling["dry_base"].as<float>();
321+
if (sampling["dry_allowed_length"]) params.sampling.dry_allowed_length = sampling["dry_allowed_length"].as<int32_t>();
322+
if (sampling["dry_penalty_last_n"]) params.sampling.dry_penalty_last_n = sampling["dry_penalty_last_n"].as<int32_t>();
323+
if (sampling["mirostat"]) params.sampling.mirostat = sampling["mirostat"].as<int32_t>();
324+
if (sampling["mirostat_tau"]) params.sampling.mirostat_tau = sampling["mirostat_tau"].as<float>();
325+
if (sampling["mirostat_eta"]) params.sampling.mirostat_eta = sampling["mirostat_eta"].as<float>();
326+
if (sampling["top_n_sigma"]) params.sampling.top_n_sigma = sampling["top_n_sigma"].as<float>();
327+
if (sampling["ignore_eos"]) params.sampling.ignore_eos = sampling["ignore_eos"].as<bool>();
328+
if (sampling["no_perf"]) params.sampling.no_perf = sampling["no_perf"].as<bool>();
329+
if (sampling["timing_per_token"]) params.sampling.timing_per_token = sampling["timing_per_token"].as<bool>();
330+
if (sampling["grammar"]) params.sampling.grammar = sampling["grammar"].as<std::string>();
331+
if (sampling["grammar_lazy"]) params.sampling.grammar_lazy = sampling["grammar_lazy"].as<bool>();
332+
}
333+
334+
return true;
335+
} catch (const YAML::Exception & e) {
336+
throw std::invalid_argument("YAML parsing error: " + std::string(e.what()));
337+
} catch (const std::exception & e) {
338+
throw std::invalid_argument("Config loading error: " + std::string(e.what()));
339+
}
340+
}

common/config.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include "common.h"
4+
#include <string>
5+
6+
// Loads the YAML configuration file at `path` and applies its settings to
// `params`. Keys are validated before anything is applied, and relative file
// paths in the config are resolved against the directory containing the YAML
// file. Returns true on success; throws std::invalid_argument on YAML parse
// errors, unknown keys, or values that cannot be converted.
bool common_load_yaml_config(const std::string & path, common_params & params);
7+
// Returns every accepted configuration key (dotted paths for nested sections,
// e.g. "sampling.top_k") joined into one comma-separated string. The same text
// is appended to the error raised for unknown keys.
std::string common_yaml_valid_keys_help();

0 commit comments

Comments
 (0)