Skip to content

Commit 2f503ce

Browse files
Add YAML config file support to llama-cli
- Add yaml-cpp dependency as vendored library - Extend argument parsing to support --config option - Implement common_params_load_from_yaml() function - Add comprehensive test suite for YAML functionality - Maintain full backward compatibility with existing CLI - CLI arguments override YAML config values (correct precedence) - Add example config file and documentation - All existing tests pass with no regressions Co-Authored-By: Jaime Mizrachi <[email protected]>
1 parent 661ae31 commit 2f503ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+12060
-1
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ endif()
179179
# build the library
180180
#
181181

182+
add_subdirectory(vendor/yaml-cpp)
183+
182184
add_subdirectory(src)
183185

184186
#

common/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ endif ()
135135

136136
target_include_directories(${TARGET} PUBLIC . ../vendor)
137137
target_compile_features (${TARGET} PUBLIC cxx_std_17)
138-
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
138+
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} yaml-cpp PUBLIC llama Threads::Threads)
139139

140140

141141
#

common/arg.cpp

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#define JSON_ASSERT GGML_ASSERT
2020
#include <nlohmann/json.hpp>
21+
#include <yaml-cpp/yaml.h>
2122

2223
#include <algorithm>
2324
#include <climits>
@@ -65,6 +66,169 @@ static void write_file(const std::string & fname, const std::string & content) {
6566
file.close();
6667
}
6768

69+
bool common_params_load_from_yaml(const std::string & config_file, common_params & params) {
70+
try {
71+
YAML::Node config = YAML::LoadFile(config_file);
72+
73+
// Model parameters
74+
if (config["model"]) {
75+
if (config["model"]["path"]) {
76+
params.model.path = config["model"]["path"].as<std::string>();
77+
}
78+
if (config["model"]["url"]) {
79+
params.model.url = config["model"]["url"].as<std::string>();
80+
}
81+
if (config["model"]["hf_repo"]) {
82+
params.model.hf_repo = config["model"]["hf_repo"].as<std::string>();
83+
}
84+
if (config["model"]["hf_file"]) {
85+
params.model.hf_file = config["model"]["hf_file"].as<std::string>();
86+
}
87+
}
88+
89+
// Basic parameters
90+
if (config["n_predict"]) params.n_predict = config["n_predict"].as<int32_t>();
91+
if (config["n_ctx"]) params.n_ctx = config["n_ctx"].as<int32_t>();
92+
if (config["n_batch"]) params.n_batch = config["n_batch"].as<int32_t>();
93+
if (config["n_ubatch"]) params.n_ubatch = config["n_ubatch"].as<int32_t>();
94+
if (config["n_keep"]) params.n_keep = config["n_keep"].as<int32_t>();
95+
if (config["n_chunks"]) params.n_chunks = config["n_chunks"].as<int32_t>();
96+
if (config["n_parallel"]) params.n_parallel = config["n_parallel"].as<int32_t>();
97+
if (config["n_sequences"]) params.n_sequences = config["n_sequences"].as<int32_t>();
98+
if (config["n_gpu_layers"]) params.n_gpu_layers = config["n_gpu_layers"].as<int32_t>();
99+
if (config["main_gpu"]) params.main_gpu = config["main_gpu"].as<int32_t>();
100+
if (config["verbosity"]) params.verbosity = config["verbosity"].as<int32_t>();
101+
102+
// String parameters
103+
if (config["prompt"]) params.prompt = config["prompt"].as<std::string>();
104+
if (config["system_prompt"]) params.system_prompt = config["system_prompt"].as<std::string>();
105+
if (config["prompt_file"]) params.prompt_file = config["prompt_file"].as<std::string>();
106+
if (config["input_prefix"]) params.input_prefix = config["input_prefix"].as<std::string>();
107+
if (config["input_suffix"]) params.input_suffix = config["input_suffix"].as<std::string>();
108+
if (config["hf_token"]) params.hf_token = config["hf_token"].as<std::string>();
109+
110+
// Float parameters
111+
if (config["rope_freq_base"]) params.rope_freq_base = config["rope_freq_base"].as<float>();
112+
if (config["rope_freq_scale"]) params.rope_freq_scale = config["rope_freq_scale"].as<float>();
113+
if (config["yarn_ext_factor"]) params.yarn_ext_factor = config["yarn_ext_factor"].as<float>();
114+
if (config["yarn_attn_factor"]) params.yarn_attn_factor = config["yarn_attn_factor"].as<float>();
115+
if (config["yarn_beta_fast"]) params.yarn_beta_fast = config["yarn_beta_fast"].as<float>();
116+
if (config["yarn_beta_slow"]) params.yarn_beta_slow = config["yarn_beta_slow"].as<float>();
117+
118+
// Boolean parameters
119+
if (config["interactive"]) params.interactive = config["interactive"].as<bool>();
120+
if (config["interactive_first"]) params.interactive_first = config["interactive_first"].as<bool>();
121+
if (config["conversation"]) {
122+
params.conversation_mode = config["conversation"].as<bool>() ?
123+
COMMON_CONVERSATION_MODE_ENABLED : COMMON_CONVERSATION_MODE_DISABLED;
124+
}
125+
if (config["use_color"]) params.use_color = config["use_color"].as<bool>();
126+
if (config["simple_io"]) params.simple_io = config["simple_io"].as<bool>();
127+
if (config["embedding"]) params.embedding = config["embedding"].as<bool>();
128+
if (config["escape"]) params.escape = config["escape"].as<bool>();
129+
if (config["multiline_input"]) params.multiline_input = config["multiline_input"].as<bool>();
130+
if (config["cont_batching"]) params.cont_batching = config["cont_batching"].as<bool>();
131+
if (config["flash_attn"]) {
132+
params.flash_attn_type = config["flash_attn"].as<bool>() ?
133+
LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
134+
}
135+
if (config["no_perf"]) params.no_perf = config["no_perf"].as<bool>();
136+
if (config["ctx_shift"]) params.ctx_shift = config["ctx_shift"].as<bool>();
137+
if (config["input_prefix_bos"]) params.input_prefix_bos = config["input_prefix_bos"].as<bool>();
138+
if (config["use_mmap"]) params.use_mmap = config["use_mmap"].as<bool>();
139+
if (config["use_mlock"]) params.use_mlock = config["use_mlock"].as<bool>();
140+
if (config["verbose_prompt"]) params.verbose_prompt = config["verbose_prompt"].as<bool>();
141+
if (config["display_prompt"]) params.display_prompt = config["display_prompt"].as<bool>();
142+
if (config["no_kv_offload"]) params.no_kv_offload = config["no_kv_offload"].as<bool>();
143+
if (config["warmup"]) params.warmup = config["warmup"].as<bool>();
144+
if (config["check_tensors"]) params.check_tensors = config["check_tensors"].as<bool>();
145+
146+
// CPU parameters
147+
if (config["cpuparams"]) {
148+
const auto & cpu_config = config["cpuparams"];
149+
if (cpu_config["n_threads"]) params.cpuparams.n_threads = cpu_config["n_threads"].as<int>();
150+
if (cpu_config["strict_cpu"]) params.cpuparams.strict_cpu = cpu_config["strict_cpu"].as<bool>();
151+
if (cpu_config["poll"]) params.cpuparams.poll = cpu_config["poll"].as<uint32_t>();
152+
}
153+
154+
// Sampling parameters
155+
if (config["sampling"]) {
156+
const auto & sampling_config = config["sampling"];
157+
if (sampling_config["seed"]) params.sampling.seed = sampling_config["seed"].as<uint32_t>();
158+
if (sampling_config["n_prev"]) params.sampling.n_prev = sampling_config["n_prev"].as<int32_t>();
159+
if (sampling_config["n_probs"]) params.sampling.n_probs = sampling_config["n_probs"].as<int32_t>();
160+
if (sampling_config["min_keep"]) params.sampling.min_keep = sampling_config["min_keep"].as<int32_t>();
161+
if (sampling_config["top_k"]) params.sampling.top_k = sampling_config["top_k"].as<int32_t>();
162+
if (sampling_config["top_p"]) params.sampling.top_p = sampling_config["top_p"].as<float>();
163+
if (sampling_config["min_p"]) params.sampling.min_p = sampling_config["min_p"].as<float>();
164+
if (sampling_config["xtc_probability"]) params.sampling.xtc_probability = sampling_config["xtc_probability"].as<float>();
165+
if (sampling_config["xtc_threshold"]) params.sampling.xtc_threshold = sampling_config["xtc_threshold"].as<float>();
166+
if (sampling_config["typ_p"]) params.sampling.typ_p = sampling_config["typ_p"].as<float>();
167+
if (sampling_config["temp"]) params.sampling.temp = sampling_config["temp"].as<float>();
168+
if (sampling_config["dynatemp_range"]) params.sampling.dynatemp_range = sampling_config["dynatemp_range"].as<float>();
169+
if (sampling_config["dynatemp_exponent"]) params.sampling.dynatemp_exponent = sampling_config["dynatemp_exponent"].as<float>();
170+
if (sampling_config["penalty_last_n"]) params.sampling.penalty_last_n = sampling_config["penalty_last_n"].as<int32_t>();
171+
if (sampling_config["penalty_repeat"]) params.sampling.penalty_repeat = sampling_config["penalty_repeat"].as<float>();
172+
if (sampling_config["penalty_freq"]) params.sampling.penalty_freq = sampling_config["penalty_freq"].as<float>();
173+
if (sampling_config["penalty_present"]) params.sampling.penalty_present = sampling_config["penalty_present"].as<float>();
174+
if (sampling_config["dry_multiplier"]) params.sampling.dry_multiplier = sampling_config["dry_multiplier"].as<float>();
175+
if (sampling_config["dry_base"]) params.sampling.dry_base = sampling_config["dry_base"].as<float>();
176+
if (sampling_config["dry_allowed_length"]) params.sampling.dry_allowed_length = sampling_config["dry_allowed_length"].as<int32_t>();
177+
if (sampling_config["dry_penalty_last_n"]) params.sampling.dry_penalty_last_n = sampling_config["dry_penalty_last_n"].as<int32_t>();
178+
if (sampling_config["mirostat"]) params.sampling.mirostat = sampling_config["mirostat"].as<int32_t>();
179+
if (sampling_config["top_n_sigma"]) params.sampling.top_n_sigma = sampling_config["top_n_sigma"].as<float>();
180+
if (sampling_config["mirostat_tau"]) params.sampling.mirostat_tau = sampling_config["mirostat_tau"].as<float>();
181+
if (sampling_config["mirostat_eta"]) params.sampling.mirostat_eta = sampling_config["mirostat_eta"].as<float>();
182+
if (sampling_config["ignore_eos"]) params.sampling.ignore_eos = sampling_config["ignore_eos"].as<bool>();
183+
if (sampling_config["no_perf"]) params.sampling.no_perf = sampling_config["no_perf"].as<bool>();
184+
if (sampling_config["timing_per_token"]) params.sampling.timing_per_token = sampling_config["timing_per_token"].as<bool>();
185+
if (sampling_config["grammar"]) params.sampling.grammar = sampling_config["grammar"].as<std::string>();
186+
if (sampling_config["grammar_lazy"]) params.sampling.grammar_lazy = sampling_config["grammar_lazy"].as<bool>();
187+
188+
if (sampling_config["dry_sequence_breakers"]) {
189+
params.sampling.dry_sequence_breakers.clear();
190+
for (const auto & breaker : sampling_config["dry_sequence_breakers"]) {
191+
params.sampling.dry_sequence_breakers.push_back(breaker.as<std::string>());
192+
}
193+
}
194+
}
195+
196+
// Speculative parameters
197+
if (config["speculative"]) {
198+
const auto & spec_config = config["speculative"];
199+
if (spec_config["n_ctx"]) params.speculative.n_ctx = spec_config["n_ctx"].as<int32_t>();
200+
if (spec_config["n_max"]) params.speculative.n_max = spec_config["n_max"].as<int32_t>();
201+
if (spec_config["n_min"]) params.speculative.n_min = spec_config["n_min"].as<int32_t>();
202+
if (spec_config["n_gpu_layers"]) params.speculative.n_gpu_layers = spec_config["n_gpu_layers"].as<int32_t>();
203+
if (spec_config["p_split"]) params.speculative.p_split = spec_config["p_split"].as<float>();
204+
if (spec_config["p_min"]) params.speculative.p_min = spec_config["p_min"].as<float>();
205+
206+
if (spec_config["model"]) {
207+
const auto & model_config = spec_config["model"];
208+
if (model_config["path"]) params.speculative.model.path = model_config["path"].as<std::string>();
209+
if (model_config["url"]) params.speculative.model.url = model_config["url"].as<std::string>();
210+
if (model_config["hf_repo"]) params.speculative.model.hf_repo = model_config["hf_repo"].as<std::string>();
211+
if (model_config["hf_file"]) params.speculative.model.hf_file = model_config["hf_file"].as<std::string>();
212+
}
213+
}
214+
215+
if (config["antiprompt"]) {
216+
params.antiprompt.clear();
217+
for (const auto & antiprompt : config["antiprompt"]) {
218+
params.antiprompt.push_back(antiprompt.as<std::string>());
219+
}
220+
}
221+
222+
return true;
223+
} catch (const YAML::Exception & e) {
224+
LOG_ERR("Error parsing YAML config file '%s': %s\n", config_file.c_str(), e.what());
225+
return false;
226+
} catch (const std::exception & e) {
227+
LOG_ERR("Error loading YAML config file '%s': %s\n", config_file.c_str(), e.what());
228+
return false;
229+
}
230+
}
231+
68232
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
69233
this->examples = std::move(examples);
70234
return *this;
@@ -1227,6 +1391,20 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
12271391
ctx_arg.params = params_org;
12281392
return false;
12291393
}
1394+
1395+
// Load YAML config if specified
1396+
if (!ctx_arg.params.config_file.empty()) {
1397+
if (!common_params_load_from_yaml(ctx_arg.params.config_file, ctx_arg.params)) {
1398+
ctx_arg.params = params_org;
1399+
return false;
1400+
}
1401+
1402+
if (!common_params_parse_ex(argc, argv, ctx_arg)) {
1403+
ctx_arg.params = params_org;
1404+
return false;
1405+
}
1406+
}
1407+
12301408
if (ctx_arg.params.usage) {
12311409
common_params_print_usage(ctx_arg);
12321410
if (ctx_arg.print_usage) {
@@ -1317,6 +1495,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
13171495
params.completion = true;
13181496
}
13191497
));
1498+
add_opt(common_arg(
1499+
{"--config"}, "FNAME",
1500+
"path to a YAML config file (default: none)",
1501+
[](common_params & params, const std::string & value) {
1502+
params.config_file = value;
1503+
}
1504+
));
13201505
add_opt(common_arg(
13211506
{"--verbose-prompt"},
13221507
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),

common/arg.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ struct common_params_context {
7676
// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
7777
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
7878

79+
// load parameters from YAML config file
80+
bool common_params_load_from_yaml(const std::string & config_file, common_params & params);
81+
7982
// function to be used by test-arg-parser
8083
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
8184
bool common_has_curl();

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ struct common_params {
332332
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
333333
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
334334
std::string logits_file = ""; // file for saving *all* logits // NOLINT
335+
std::string config_file = ""; // path to YAML config file // NOLINT
335336

336337
std::vector<std::string> in_files; // all input files
337338
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)

examples/config.yaml

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
2+
model:
3+
path: "models/my-model.gguf"
4+
5+
n_predict: 128
6+
n_ctx: 4096
7+
n_batch: 512
8+
n_ubatch: 512
9+
n_keep: 0
10+
n_gpu_layers: 32
11+
main_gpu: 0
12+
verbosity: 1
13+
14+
prompt: "Hello, how are you?"
15+
system_prompt: "You are a helpful assistant."
16+
input_prefix: "User: "
17+
input_suffix: "\nAssistant: "
18+
19+
rope_freq_base: 10000.0
20+
rope_freq_scale: 1.0
21+
yarn_ext_factor: 1.0
22+
yarn_attn_factor: 1.0
23+
yarn_beta_fast: 32.0
24+
yarn_beta_slow: 1.0
25+
26+
interactive: false
27+
interactive_first: false
28+
conversation: false
29+
use_color: true
30+
simple_io: false
31+
embedding: false
32+
escape: true
33+
multiline_input: false
34+
cont_batching: true
35+
flash_attn: false
36+
no_perf: false
37+
ctx_shift: true
38+
input_prefix_bos: false
39+
logits_all: false
40+
use_mmap: true
41+
use_mlock: false
42+
verbose_prompt: false
43+
display_prompt: true
44+
dump_kv_cache: false
45+
no_kv_offload: false
46+
warmup: true
47+
check_tensors: false
48+
49+
cache_type_k: "f16"
50+
cache_type_v: "f16"
51+
52+
cpuparams:
53+
n_threads: 8
54+
strict_cpu: false
55+
poll: 50
56+
57+
sampling:
58+
seed: 42
59+
n_prev: 64
60+
n_probs: 0
61+
min_keep: 0
62+
top_k: 40
63+
top_p: 0.95
64+
min_p: 0.05
65+
xtc_probability: 0.0
66+
xtc_threshold: 0.1
67+
typ_p: 1.0
68+
temp: 0.8
69+
dynatemp_range: 0.0
70+
dynatemp_exponent: 1.0
71+
penalty_last_n: 64
72+
penalty_repeat: 1.0
73+
penalty_freq: 0.0
74+
penalty_present: 0.0
75+
dry_multiplier: 0.0
76+
dry_base: 1.75
77+
dry_allowed_length: 2
78+
dry_penalty_last_n: -1
79+
mirostat: 0
80+
top_n_sigma: 0.0
81+
mirostat_tau: 5.0
82+
mirostat_eta: 0.1
83+
ignore_eos: false
84+
no_perf: false
85+
timing_per_token: false
86+
grammar: ""
87+
grammar_lazy: false
88+
dry_sequence_breakers:
89+
- "\n"
90+
- ":"
91+
- "\""
92+
93+
speculative:
94+
n_ctx: 0
95+
n_max: 16
96+
n_min: 5
97+
n_gpu_layers: 0
98+
p_split: 0.1
99+
p_min: 0.9
100+
model:
101+
path: ""
102+
103+
antiprompt:
104+
- "User:"
105+
- "Human:"
106+
- "\n\n"

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyll
190190
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
191191
if (NOT WIN32)
192192
llama_build_and_test(test-arg-parser.cpp)
193+
llama_build_and_test(test-yaml-config.cpp)
193194
endif()
194195

195196
if (NOT LLAMA_SANITIZE_ADDRESS)

0 commit comments

Comments
 (0)