Commit 1af90f6

Add YAML configuration file support to CLI tools
- Add yaml-cpp dependency with proper CMake integration
- Add --config argument to accept YAML configuration files
- Implement comprehensive YAML parameter parsing in common_params_load_from_yaml
- Support all major CLI parameters: model, threads, context size, sampling, etc.
- CLI arguments override config file values (proper precedence)
- Add comprehensive tests for YAML functionality and backward compatibility
- Include example YAML config files for main and server tools
- Maintain full backward compatibility with existing CLI usage
- Add proper error handling for invalid YAML files and missing files

Co-Authored-By: Jaime Mizrachi <[email protected]>
1 parent 661ae31 commit 1af90f6

8 files changed: 379 additions, 0 deletions
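
In practice (binary name illustrative, not part of this diff): `llama-cli --config examples/config.yaml` picks up every value from the file, and `llama-cli --config examples/config.yaml -t 16` then overrides the thread count on the command line, since arguments are applied in the order they appear.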

CMakeLists.txt

Lines changed: 25 additions & 0 deletions
@@ -86,6 +86,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
 # 3rd party libs
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
 option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
+option(LLAMA_YAML_CPP "llama: use yaml-cpp for YAML config file support" ON)

 # Required for relocatable CMake package
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
@@ -190,6 +191,30 @@ if (NOT LLAMA_BUILD_COMMON)
     set(LLAMA_CURL OFF)
 endif()

+# Find yaml-cpp if enabled
+if (LLAMA_YAML_CPP)
+    find_package(PkgConfig QUIET)
+    if (PkgConfig_FOUND)
+        pkg_check_modules(YAML_CPP QUIET yaml-cpp)
+    endif()
+
+    if (NOT YAML_CPP_FOUND)
+        find_package(yaml-cpp QUIET)
+        if (yaml-cpp_FOUND)
+            set(YAML_CPP_LIBRARIES yaml-cpp)
+            set(YAML_CPP_INCLUDE_DIRS ${yaml-cpp_INCLUDE_DIRS})
+        endif()
+    endif()
+
+    if (NOT YAML_CPP_FOUND AND NOT yaml-cpp_FOUND)
+        message(STATUS "yaml-cpp not found, disabling YAML config support")
+        set(LLAMA_YAML_CPP OFF)
+    else()
+        message(STATUS "yaml-cpp found, enabling YAML config support")
+        add_compile_definitions(LLAMA_YAML_CPP)
+    endif()
+endif()
+
 if (LLAMA_BUILD_COMMON)
     add_subdirectory(common)
 endif()
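
Note that the two detection paths publish different variables: the pkg-config route sets YAML_CPP_FOUND (plus YAML_CPP_LIBRARIES and YAML_CPP_INCLUDE_DIRS), while the find_package(yaml-cpp) fallback sets yaml-cpp_FOUND. The branching in common/CMakeLists.txt below mirrors that split.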

common/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -137,6 +137,15 @@ target_include_directories(${TARGET} PUBLIC . ../vendor)
 target_compile_features (${TARGET} PUBLIC cxx_std_17)
 target_link_libraries   (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)

+if (LLAMA_YAML_CPP AND YAML_CPP_FOUND)
+    target_link_libraries(${TARGET} PRIVATE ${YAML_CPP_LIBRARIES})
+    target_include_directories(${TARGET} PRIVATE ${YAML_CPP_INCLUDE_DIRS})
+    target_compile_definitions(${TARGET} PRIVATE LLAMA_YAML_CPP)
+elseif (LLAMA_YAML_CPP AND yaml-cpp_FOUND)
+    target_link_libraries(${TARGET} PRIVATE yaml-cpp)
+    target_compile_definitions(${TARGET} PRIVATE LLAMA_YAML_CPP)
+endif()
+

 #
 # copy the license files

common/arg.cpp

Lines changed: 190 additions & 0 deletions
@@ -19,6 +19,10 @@
 #define JSON_ASSERT GGML_ASSERT
 #include <nlohmann/json.hpp>

+#ifdef LLAMA_YAML_CPP
+#include <yaml-cpp/yaml.h>
+#endif
+
 #include <algorithm>
 #include <climits>
 #include <cstdarg>
@@ -65,6 +69,177 @@ static void write_file(const std::string & fname, const std::string & content) {
     file.close();
 }

+#ifdef LLAMA_YAML_CPP
+static bool common_params_load_from_yaml(const std::string & config_file, common_params & params) {
+    if (config_file.empty()) {
+        return true;
+    }
+
+    try {
+        YAML::Node config = YAML::LoadFile(config_file);
+
+        // Model parameters
+        if (config["model"]) {
+            params.model.path = config["model"].as<std::string>();
+        }
+        if (config["model_url"]) {
+            params.model.url = config["model_url"].as<std::string>();
+        }
+        if (config["model_alias"]) {
+            params.model_alias = config["model_alias"].as<std::string>();
+        }
+        if (config["hf_repo"]) {
+            params.model.hf_repo = config["hf_repo"].as<std::string>();
+        }
+        if (config["hf_file"]) {
+            params.model.hf_file = config["hf_file"].as<std::string>();
+        }
+        if (config["hf_token"]) {
+            params.hf_token = config["hf_token"].as<std::string>();
+        }
+
+        // Context and prediction parameters
+        if (config["ctx_size"]) {
+            params.n_ctx = config["ctx_size"].as<int32_t>();
+        }
+        if (config["predict"]) {
+            params.n_predict = config["predict"].as<int32_t>();
+        }
+        if (config["batch_size"]) {
+            params.n_batch = config["batch_size"].as<int32_t>();
+        }
+        if (config["ubatch_size"]) {
+            params.n_ubatch = config["ubatch_size"].as<int32_t>();
+        }
+        if (config["keep"]) {
+            params.n_keep = config["keep"].as<int32_t>();
+        }
+        if (config["chunks"]) {
+            params.n_chunks = config["chunks"].as<int32_t>();
+        }
+        if (config["parallel"]) {
+            params.n_parallel = config["parallel"].as<int32_t>();
+        }
+        if (config["sequences"]) {
+            params.n_sequences = config["sequences"].as<int32_t>();
+        }
+
+        // CPU parameters
+        if (config["threads"]) {
+            params.cpuparams.n_threads = config["threads"].as<int>();
+        }
+        if (config["threads_batch"]) {
+            params.cpuparams_batch.n_threads = config["threads_batch"].as<int>();
+        }
+
+        // GPU parameters
+        if (config["n_gpu_layers"]) {
+            params.n_gpu_layers = config["n_gpu_layers"].as<int32_t>();
+        }
+        if (config["main_gpu"]) {
+            params.main_gpu = config["main_gpu"].as<int32_t>();
+        }
+
+        // Sampling parameters
+        if (config["seed"]) {
+            params.sampling.seed = config["seed"].as<uint32_t>();
+        }
+        if (config["temperature"]) {
+            params.sampling.temp = config["temperature"].as<float>();
+        }
+        if (config["top_k"]) {
+            params.sampling.top_k = config["top_k"].as<int32_t>();
+        }
+        if (config["top_p"]) {
+            params.sampling.top_p = config["top_p"].as<float>();
+        }
+        if (config["min_p"]) {
+            params.sampling.min_p = config["min_p"].as<float>();
+        }
+        if (config["typical_p"]) {
+            params.sampling.typ_p = config["typical_p"].as<float>();
+        }
+        if (config["repeat_last_n"]) {
+            params.sampling.penalty_last_n = config["repeat_last_n"].as<int32_t>();
+        }
+        if (config["repeat_penalty"]) {
+            params.sampling.penalty_repeat = config["repeat_penalty"].as<float>();
+        }
+        if (config["frequency_penalty"]) {
+            params.sampling.penalty_freq = config["frequency_penalty"].as<float>();
+        }
+        if (config["presence_penalty"]) {
+            params.sampling.penalty_present = config["presence_penalty"].as<float>();
+        }
+        if (config["mirostat"]) {
+            params.sampling.mirostat = config["mirostat"].as<int32_t>();
+        }
+        if (config["mirostat_tau"]) {
+            params.sampling.mirostat_tau = config["mirostat_tau"].as<float>();
+        }
+        if (config["mirostat_eta"]) {
+            params.sampling.mirostat_eta = config["mirostat_eta"].as<float>();
+        }
+
+        // Prompt and system parameters
+        if (config["prompt"]) {
+            params.prompt = config["prompt"].as<std::string>();
+        }
+        if (config["system_prompt"]) {
+            params.system_prompt = config["system_prompt"].as<std::string>();
+        }
+        if (config["prompt_file"]) {
+            params.prompt_file = config["prompt_file"].as<std::string>();
+        }
+        if (config["prompt_cache"]) {
+            params.path_prompt_cache = config["prompt_cache"].as<std::string>();
+        }
+
+        // Input/Output parameters
+        if (config["input_prefix"]) {
+            params.input_prefix = config["input_prefix"].as<std::string>();
+        }
+        if (config["input_suffix"]) {
+            params.input_suffix = config["input_suffix"].as<std::string>();
+        }
+
+        if (config["verbose"]) {
+            params.verbosity = config["verbose"].as<int32_t>();
+        }
+
+        if (config["conversation"]) {
+            bool conv = config["conversation"].as<bool>();
+            params.conversation_mode = conv ? COMMON_CONVERSATION_MODE_ENABLED : COMMON_CONVERSATION_MODE_DISABLED;
+        }
+
+        if (config["interactive"]) {
+            params.interactive = config["interactive"].as<bool>();
+        }
+        if (config["interactive_first"]) {
+            params.interactive_first = config["interactive_first"].as<bool>();
+        }
+
+        if (config["antiprompt"]) {
+            if (config["antiprompt"].IsSequence()) {
+                for (const auto & item : config["antiprompt"]) {
+                    params.antiprompt.push_back(item.as<std::string>());
+                }
+            } else {
+                params.antiprompt.push_back(config["antiprompt"].as<std::string>());
+            }
+        }
+
+        return true;
+    } catch (const YAML::Exception & e) {
+        fprintf(stderr, "Error parsing YAML config file '%s': %s\n", config_file.c_str(), e.what());
+        return false;
+    } catch (const std::exception & e) {
+        fprintf(stderr, "Error loading YAML config file '%s': %s\n", config_file.c_str(), e.what());
+        return false;
+    }
+}
+#endif
+
 common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
     this->examples = std::move(examples);
     return *this;
@@ -1301,6 +1476,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.usage = true;
         }
     ));
+
+#ifdef LLAMA_YAML_CPP
+    add_opt(common_arg(
+        {"--config"},
+        "CONFIG_FILE",
+        "path to YAML configuration file",
+        [](common_params & params, const std::string & value) {
+            params.config_file = value;
+            if (!common_params_load_from_yaml(value, params)) {
+                throw std::invalid_argument("failed to load YAML config file: " + value);
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
+#endif
+
     add_opt(common_arg(
         {"--version"},
         "show version and build info",

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -332,6 +332,7 @@ struct common_params {
     std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
     std::string logits_file = ""; // file for saving *all* logits // NOLINT
+    std::string config_file = ""; // path to YAML configuration file // NOLINT

     std::vector<std::string> in_files; // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)

examples/config.yaml

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+
+model: "models/7B/ggml-model-f16.gguf"
+
+ctx_size: 2048          # Context size (number of tokens)
+predict: 128            # Number of tokens to predict (-1 for unlimited)
+batch_size: 512         # Batch size for prompt processing
+ubatch_size: 512        # Physical batch size
+keep: 0                 # Number of tokens to keep from initial prompt
+chunks: -1              # Max number of chunks to process (-1 = unlimited)
+parallel: 1             # Number of parallel sequences
+sequences: 1            # Number of sequences to decode
+
+threads: 4              # Number of threads to use
+threads_batch: 4        # Number of threads for batch processing
+
+n_gpu_layers: -1        # Number of layers to offload to GPU (-1 = all)
+main_gpu: 0             # Main GPU to use
+
+seed: -1                # Random seed (-1 for random)
+temperature: 0.8        # Sampling temperature
+top_k: 40               # Top-k sampling
+top_p: 0.95             # Top-p (nucleus) sampling
+min_p: 0.05             # Min-p sampling
+typical_p: 1.0          # Typical-p sampling
+repeat_last_n: 64       # Last n tokens to consider for repetition penalty
+repeat_penalty: 1.1     # Repetition penalty
+frequency_penalty: 0.0  # Frequency penalty
+presence_penalty: 0.0   # Presence penalty
+mirostat: 0             # Mirostat sampling (0=disabled, 1=v1, 2=v2)
+mirostat_tau: 5.0       # Mirostat target entropy
+mirostat_eta: 0.1       # Mirostat learning rate
+
+
+
+verbose: 0              # Verbosity level (0=quiet, 1=normal, 2=verbose)
+conversation: false     # Enable conversation mode
+interactive: false      # Enable interactive mode
+interactive_first: false # Start in interactive mode
+
+antiprompt:
+  - "User:"
+  - "Human:"
+  - "\n\n"

tests/test-arg-parser.cpp

Lines changed: 69 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <vector>
 #include <sstream>
 #include <unordered_set>
+#include <fstream>

 #undef NDEBUG
 #include <cassert>
@@ -174,5 +175,73 @@ int main(void) {
         printf("test-arg-parser: no curl, skipping curl-related functions\n");
     }

+    printf("test-arg-parser: all tests OK\n\n");
+
+#ifdef LLAMA_YAML_CPP
+    printf("test-arg-parser: testing YAML config functionality\n\n");
+
+    std::string yaml_content = R"(
+model: "test_model.gguf"
+threads: 8
+ctx_size: 4096
+predict: 256
+temperature: 0.7
+top_k: 50
+top_p: 0.9
+seed: 12345
+verbose: 1
+conversation: true
+antiprompt:
+  - "User:"
+  - "Stop"
+)";
+
+    std::string temp_config = "/tmp/test_config.yaml";
+    std::ofstream config_file(temp_config);
+    config_file << yaml_content;
+    config_file.close();
+
+    argv = {"binary_name", "--config", temp_config.c_str()};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.model.path == "test_model.gguf");
+    assert(params.cpuparams.n_threads == 8);
+    assert(params.n_ctx == 4096);
+    assert(params.n_predict == 256);
+    assert(params.sampling.temp == 0.7f);
+    assert(params.sampling.top_k == 50);
+    assert(params.sampling.top_p == 0.9f);
+    assert(params.sampling.seed == 12345);
+    assert(params.verbosity == 1);
+    assert(params.conversation_mode == COMMON_CONVERSATION_MODE_ENABLED);
+    assert(params.antiprompt.size() == 2);
+    assert(params.antiprompt[0] == "User:");
+    assert(params.antiprompt[1] == "Stop");
+
+    argv = {"binary_name", "--config", temp_config.c_str(), "-t", "16", "--ctx-size", "8192"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.model.path == "test_model.gguf"); // from config
+    assert(params.cpuparams.n_threads == 16);       // overridden by CLI
+    assert(params.n_ctx == 8192);                   // overridden by CLI
+    assert(params.sampling.temp == 0.7f);           // from config
+
+    std::string invalid_yaml = "/tmp/invalid_config.yaml";
+    std::ofstream invalid_file(invalid_yaml);
+    invalid_file << "invalid: yaml: content: [unclosed";
+    invalid_file.close();
+
+    argv = {"binary_name", "--config", invalid_yaml.c_str()};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
+    argv = {"binary_name", "--config", "/tmp/nonexistent_config.yaml"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
+    std::remove(temp_config.c_str());
+    std::remove(invalid_yaml.c_str());
+
+    printf("test-arg-parser: YAML config tests passed\n\n");
+#else
+    printf("test-arg-parser: YAML config support not compiled, skipping YAML tests\n\n");
+#endif
+
     printf("test-arg-parser: all tests OK\n\n");
 }
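
Two practical notes on these tests: the fixtures are written to hard-coded /tmp paths, so as written they assume a POSIX environment; and the patch adds a "test-arg-parser: all tests OK" print before the YAML block in addition to the pre-existing one at the end, so that message appears twice in a successful run.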
