Skip to content

Commit 2e91cef

Browse files
feat: add YAML configuration file support to CLI
- Add --config flag to accept YAML configuration files
- Implement comprehensive YAML parsing with error handling
- Maintain 100% backward compatibility with existing CLI args
- CLI arguments override YAML configuration values
- Add comprehensive test suite for YAML functionality
- Update documentation with YAML configuration examples
- Integrate yaml-cpp dependency via FetchContent

The implementation allows users to specify configuration via YAML files while preserving all existing CLI functionality. CLI arguments take precedence over YAML values when both are provided.

Co-Authored-By: Jake Cosme <[email protected]>
1 parent 661ae31 commit 2e91cef

File tree

7 files changed

+623
-1
lines changed

7 files changed

+623
-1
lines changed

CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
22
project("llama.cpp" C CXX)
33
include(CheckIncludeFileCXX)
4+
include(FetchContent)
45

56
#set(CMAKE_WARN_DEPRECATED YES)
67
set(CMAKE_WARN_UNUSED_CLI YES)
@@ -87,6 +88,14 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
8788
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
8889
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
8990

91+
# Add yaml-cpp dependency
92+
FetchContent_Declare(
93+
yaml-cpp
94+
GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
95+
GIT_TAG yaml-cpp-0.7.0
96+
)
97+
FetchContent_MakeAvailable(yaml-cpp)
98+
9099
# Required for relocatable CMake package
91100
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
92101
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)

common/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ endif ()
135135

136136
target_include_directories(${TARGET} PUBLIC . ../vendor)
137137
target_compile_features (${TARGET} PUBLIC cxx_std_17)
138-
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
138+
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} yaml-cpp PUBLIC llama Threads::Threads)
139139

140140

141141
#

common/arg.cpp

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include "log.h"
88
#include "sampling.h"
99

10+
#include <yaml-cpp/yaml.h>
11+
1012
// fix problem with std::min and std::max
1113
#if defined(_WIN32)
1214
#define WIN32_LEAN_AND_MEAN
@@ -41,6 +43,200 @@
4143

4244
using json = nlohmann::ordered_json;
4345

46+
// YAML configuration parsing functions
47+
// Fill `sampling` from a YAML mapping node. Only keys that are present in
// the node are applied; absent keys leave the existing value untouched, so
// defaults (and later CLI overrides) are preserved.
static void parse_yaml_sampling(const YAML::Node& node, common_params_sampling& sampling) {
    // small typed readers: assign node[key] to dst only when the key exists
    auto read_u32  = [&node](const char * key, uint32_t    & dst) { if (node[key]) dst = node[key].as<uint32_t>();    };
    auto read_i32  = [&node](const char * key, int32_t     & dst) { if (node[key]) dst = node[key].as<int32_t>();     };
    auto read_f32  = [&node](const char * key, float       & dst) { if (node[key]) dst = node[key].as<float>();       };
    auto read_bool = [&node](const char * key, bool        & dst) { if (node[key]) dst = node[key].as<bool>();        };
    auto read_str  = [&node](const char * key, std::string & dst) { if (node[key]) dst = node[key].as<std::string>(); };

    read_u32("seed",     sampling.seed);
    read_i32("n_prev",   sampling.n_prev);
    read_i32("n_probs",  sampling.n_probs);
    read_i32("min_keep", sampling.min_keep);
    read_i32("top_k",    sampling.top_k);

    read_f32("top_p",             sampling.top_p);
    read_f32("min_p",             sampling.min_p);
    read_f32("xtc_probability",   sampling.xtc_probability);
    read_f32("xtc_threshold",     sampling.xtc_threshold);
    read_f32("typ_p",             sampling.typ_p);
    read_f32("temp",              sampling.temp);
    read_f32("dynatemp_range",    sampling.dynatemp_range);
    read_f32("dynatemp_exponent", sampling.dynatemp_exponent);

    read_i32("penalty_last_n",  sampling.penalty_last_n);
    read_f32("penalty_repeat",  sampling.penalty_repeat);
    read_f32("penalty_freq",    sampling.penalty_freq);
    read_f32("penalty_present", sampling.penalty_present);

    read_f32("dry_multiplier",     sampling.dry_multiplier);
    read_f32("dry_base",           sampling.dry_base);
    read_i32("dry_allowed_length", sampling.dry_allowed_length);
    read_i32("dry_penalty_last_n", sampling.dry_penalty_last_n);

    read_i32("mirostat",     sampling.mirostat);
    read_f32("top_n_sigma",  sampling.top_n_sigma);
    read_f32("mirostat_tau", sampling.mirostat_tau);
    read_f32("mirostat_eta", sampling.mirostat_eta);

    read_bool("ignore_eos",       sampling.ignore_eos);
    read_bool("no_perf",          sampling.no_perf);
    read_bool("timing_per_token", sampling.timing_per_token);

    read_str ("grammar",      sampling.grammar);
    read_bool("grammar_lazy", sampling.grammar_lazy);

    // a sequence here replaces (not appends to) the current breaker list
    if (node["dry_sequence_breakers"] && node["dry_sequence_breakers"].IsSequence()) {
        sampling.dry_sequence_breakers.clear();
        for (const auto& breaker : node["dry_sequence_breakers"]) {
            sampling.dry_sequence_breakers.push_back(breaker.as<std::string>());
        }
    }
}
86+
87+
// Fill `model` (local path / remote download source) from a YAML mapping
// node. Absent keys leave the existing values untouched.
static void parse_yaml_model(const YAML::Node& node, common_params_model& model) {
    auto read_str = [&node](const char * key, std::string & dst) {
        if (node[key]) {
            dst = node[key].as<std::string>();
        }
    };

    read_str("path",    model.path);
    read_str("url",     model.url);
    read_str("hf_repo", model.hf_repo);
    read_str("hf_file", model.hf_file);
}
93+
94+
// Fill `spec` (speculative-decoding settings) from a YAML mapping node.
// Absent keys leave the existing values untouched.
static void parse_yaml_speculative(const YAML::Node& node, common_params_speculative& spec) {
    if (node["n_ctx"])        spec.n_ctx        = node["n_ctx"].as<int32_t>();
    if (node["n_max"])        spec.n_max        = node["n_max"].as<int32_t>();
    if (node["n_min"])        spec.n_min        = node["n_min"].as<int32_t>();
    if (node["n_gpu_layers"]) spec.n_gpu_layers = node["n_gpu_layers"].as<int32_t>();
    if (node["p_split"])      spec.p_split      = node["p_split"].as<float>();
    if (node["p_min"])        spec.p_min        = node["p_min"].as<float>();

    // map a YAML cache-type string to a ggml type; shared by both the K and V
    // cache keys (previously two copy-pasted if/else chains). On an unknown
    // string, warn and keep the caller's current value instead of silently
    // ignoring the typo.
    auto parse_cache_type = [](const std::string & s, ggml_type fallback) -> ggml_type {
        if (s == "f16")  return GGML_TYPE_F16;
        if (s == "f32")  return GGML_TYPE_F32;
        if (s == "q4_0") return GGML_TYPE_Q4_0;
        if (s == "q4_1") return GGML_TYPE_Q4_1;
        if (s == "q5_0") return GGML_TYPE_Q5_0;
        if (s == "q5_1") return GGML_TYPE_Q5_1;
        if (s == "q8_0") return GGML_TYPE_Q8_0;
        fprintf(stderr, "warning: unknown cache type '%s' in YAML config, keeping current value\n", s.c_str());
        return fallback;
    };

    if (node["cache_type_k"]) {
        spec.cache_type_k = parse_cache_type(node["cache_type_k"].as<std::string>(), spec.cache_type_k);
    }
    if (node["cache_type_v"]) {
        spec.cache_type_v = parse_cache_type(node["cache_type_v"].as<std::string>(), spec.cache_type_v);
    }

    if (node["model"]) {
        parse_yaml_model(node["model"], spec.model);
    }
}
125+
126+
// Fill `vocoder` (TTS vocoder settings) from a YAML mapping node.
// Absent keys leave the existing values untouched.
static void parse_yaml_vocoder(const YAML::Node& node, common_params_vocoder& vocoder) {
    if (node["speaker_file"]) {
        vocoder.speaker_file = node["speaker_file"].as<std::string>();
    }
    if (node["use_guide_tokens"]) {
        vocoder.use_guide_tokens = node["use_guide_tokens"].as<bool>();
    }
    // nested model section is delegated to the shared model parser
    if (node["model"]) {
        parse_yaml_model(node["model"], vocoder.model);
    }
}
133+
134+
// Fill `diffusion` (diffusion-model settings) from a YAML mapping node.
// Absent keys leave the existing values untouched.
static void parse_yaml_diffusion(const YAML::Node& node, common_params_diffusion& diffusion) {
    auto read_i32  = [&node](const char * key, int32_t & dst) { if (node[key]) dst = node[key].as<int32_t>(); };
    auto read_f32  = [&node](const char * key, float   & dst) { if (node[key]) dst = node[key].as<float>();   };
    auto read_bool = [&node](const char * key, bool    & dst) { if (node[key]) dst = node[key].as<bool>();    };

    read_i32 ("steps",            diffusion.steps);
    read_bool("visual_mode",      diffusion.visual_mode);
    read_f32 ("eps",              diffusion.eps);
    read_i32 ("block_length",     diffusion.block_length);
    read_i32 ("algorithm",        diffusion.algorithm);
    read_f32 ("alg_temp",         diffusion.alg_temp);
    read_f32 ("cfg_scale",        diffusion.cfg_scale);
    read_bool("add_gumbel_noise", diffusion.add_gumbel_noise);
}
144+
145+
static bool load_yaml_config(const std::string& config_path, common_params& params) {
146+
try {
147+
YAML::Node config = YAML::LoadFile(config_path);
148+
149+
// Parse main parameters
150+
if (config["n_predict"]) params.n_predict = config["n_predict"].as<int32_t>();
151+
if (config["n_ctx"]) params.n_ctx = config["n_ctx"].as<int32_t>();
152+
if (config["n_batch"]) params.n_batch = config["n_batch"].as<int32_t>();
153+
if (config["n_ubatch"]) params.n_ubatch = config["n_ubatch"].as<int32_t>();
154+
if (config["n_keep"]) params.n_keep = config["n_keep"].as<int32_t>();
155+
if (config["n_chunks"]) params.n_chunks = config["n_chunks"].as<int32_t>();
156+
if (config["n_parallel"]) params.n_parallel = config["n_parallel"].as<int32_t>();
157+
if (config["n_sequences"]) params.n_sequences = config["n_sequences"].as<int32_t>();
158+
if (config["grp_attn_n"]) params.grp_attn_n = config["grp_attn_n"].as<int32_t>();
159+
if (config["grp_attn_w"]) params.grp_attn_w = config["grp_attn_w"].as<int32_t>();
160+
if (config["n_print"]) params.n_print = config["n_print"].as<int32_t>();
161+
if (config["rope_freq_base"]) params.rope_freq_base = config["rope_freq_base"].as<float>();
162+
if (config["rope_freq_scale"]) params.rope_freq_scale = config["rope_freq_scale"].as<float>();
163+
if (config["yarn_ext_factor"]) params.yarn_ext_factor = config["yarn_ext_factor"].as<float>();
164+
if (config["yarn_attn_factor"]) params.yarn_attn_factor = config["yarn_attn_factor"].as<float>();
165+
if (config["yarn_beta_fast"]) params.yarn_beta_fast = config["yarn_beta_fast"].as<float>();
166+
if (config["yarn_beta_slow"]) params.yarn_beta_slow = config["yarn_beta_slow"].as<float>();
167+
if (config["yarn_orig_ctx"]) params.yarn_orig_ctx = config["yarn_orig_ctx"].as<int32_t>();
168+
if (config["n_gpu_layers"]) params.n_gpu_layers = config["n_gpu_layers"].as<int32_t>();
169+
if (config["main_gpu"]) params.main_gpu = config["main_gpu"].as<int32_t>();
170+
171+
// Parse string parameters
172+
if (config["model_alias"]) params.model_alias = config["model_alias"].as<std::string>();
173+
if (config["hf_token"]) params.hf_token = config["hf_token"].as<std::string>();
174+
if (config["prompt"]) params.prompt = config["prompt"].as<std::string>();
175+
if (config["system_prompt"]) params.system_prompt = config["system_prompt"].as<std::string>();
176+
if (config["prompt_file"]) params.prompt_file = config["prompt_file"].as<std::string>();
177+
if (config["path_prompt_cache"]) params.path_prompt_cache = config["path_prompt_cache"].as<std::string>();
178+
if (config["input_prefix"]) params.input_prefix = config["input_prefix"].as<std::string>();
179+
if (config["input_suffix"]) params.input_suffix = config["input_suffix"].as<std::string>();
180+
if (config["lookup_cache_static"]) params.lookup_cache_static = config["lookup_cache_static"].as<std::string>();
181+
if (config["lookup_cache_dynamic"]) params.lookup_cache_dynamic = config["lookup_cache_dynamic"].as<std::string>();
182+
if (config["logits_file"]) params.logits_file = config["logits_file"].as<std::string>();
183+
184+
// Parse boolean parameters
185+
if (config["lora_init_without_apply"]) params.lora_init_without_apply = config["lora_init_without_apply"].as<bool>();
186+
if (config["offline"]) params.offline = config["offline"].as<bool>();
187+
188+
// Parse integer parameters
189+
if (config["verbosity"]) params.verbosity = config["verbosity"].as<int32_t>();
190+
if (config["control_vector_layer_start"]) params.control_vector_layer_start = config["control_vector_layer_start"].as<int32_t>();
191+
if (config["control_vector_layer_end"]) params.control_vector_layer_end = config["control_vector_layer_end"].as<int32_t>();
192+
if (config["ppl_stride"]) params.ppl_stride = config["ppl_stride"].as<int32_t>();
193+
if (config["ppl_output_type"]) params.ppl_output_type = config["ppl_output_type"].as<int32_t>();
194+
195+
// Parse array parameters
196+
if (config["in_files"] && config["in_files"].IsSequence()) {
197+
params.in_files.clear();
198+
for (const auto& file : config["in_files"]) {
199+
params.in_files.push_back(file.as<std::string>());
200+
}
201+
}
202+
203+
if (config["antiprompt"] && config["antiprompt"].IsSequence()) {
204+
params.antiprompt.clear();
205+
for (const auto& prompt : config["antiprompt"]) {
206+
params.antiprompt.push_back(prompt.as<std::string>());
207+
}
208+
}
209+
210+
if (config["sampling"]) {
211+
parse_yaml_sampling(config["sampling"], params.sampling);
212+
}
213+
214+
if (config["model"]) {
215+
parse_yaml_model(config["model"], params.model);
216+
}
217+
218+
if (config["speculative"]) {
219+
parse_yaml_speculative(config["speculative"], params.speculative);
220+
}
221+
222+
if (config["vocoder"]) {
223+
parse_yaml_vocoder(config["vocoder"], params.vocoder);
224+
}
225+
226+
if (config["diffusion"]) {
227+
parse_yaml_diffusion(config["diffusion"], params.diffusion);
228+
}
229+
230+
return true;
231+
} catch (const YAML::Exception& e) {
232+
fprintf(stderr, "YAML parsing error: %s\n", e.what());
233+
return false;
234+
} catch (const std::exception& e) {
235+
fprintf(stderr, "Error loading YAML config: %s\n", e.what());
236+
return false;
237+
}
238+
}
239+
44240
std::initializer_list<enum llama_example> mmproj_examples = {
45241
LLAMA_EXAMPLE_MTMD,
46242
LLAMA_EXAMPLE_SERVER,
@@ -1223,6 +1419,17 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
12231419
const common_params params_org = ctx_arg.params; // the example can modify the default params
12241420

12251421
try {
1422+
for (int i = 1; i < argc; i++) {
1423+
if (strcmp(argv[i], "--config") == 0 && i + 1 < argc) {
1424+
if (!load_yaml_config(argv[i + 1], ctx_arg.params)) {
1425+
fprintf(stderr, "Failed to load YAML config: %s\n", argv[i + 1]);
1426+
ctx_arg.params = params_org;
1427+
return false;
1428+
}
1429+
break;
1430+
}
1431+
}
1432+
12261433
if (!common_params_parse_ex(argc, argv, ctx_arg)) {
12271434
ctx_arg.params = params_org;
12281435
return false;
@@ -1294,6 +1501,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
12941501
};
12951502

12961503

1504+
add_opt(common_arg(
1505+
{"--config"},
1506+
"FNAME",
1507+
"path to YAML configuration file",
1508+
[](common_params & params, const std::string & value) {
1509+
params.config_file = value;
1510+
}
1511+
));
1512+
12971513
add_opt(common_arg(
12981514
{"-h", "--help", "--usage"},
12991515
"print usage and exit",

common/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ struct common_params {
348348
int32_t control_vector_layer_end = -1; // layer range for control vector
349349
bool offline = false;
350350

351+
std::string config_file = ""; // path to YAML configuration file
352+
351353
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
352354
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
353355
// (which is more convenient to use for plotting)

tests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyll
190190
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
191191
if (NOT WIN32)
192192
llama_build_and_test(test-arg-parser.cpp)
193+
194+
# YAML configuration tests
195+
llama_build_and_test(test-yaml-config.cpp)
196+
llama_build_and_test(test-yaml-backward-compat.cpp)
193197
endif()
194198

195199
if (NOT LLAMA_SANITIZE_ADDRESS)

0 commit comments

Comments
 (0)