7 changes: 7 additions & 0 deletions .editorconfig
@@ -52,3 +52,10 @@ insert_final_newline = unset
[vendor/miniaudio/miniaudio.h]
trim_trailing_whitespace = unset
insert_final_newline = unset

[vendor/yaml-cpp/**]
trim_trailing_whitespace = unset
insert_final_newline = unset
indent_style = unset
end_of_line = unset
charset = unset
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -179,6 +179,8 @@ endif()
# build the library
#

add_subdirectory(vendor/yaml-cpp)

add_subdirectory(src)

#
2 changes: 1 addition & 1 deletion common/CMakeLists.txt
@@ -135,7 +135,7 @@ endif ()

target_include_directories(${TARGET} PUBLIC . ../vendor)
target_compile_features (${TARGET} PUBLIC cxx_std_17)
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} yaml-cpp PUBLIC llama Threads::Threads)


#
185 changes: 183 additions & 2 deletions common/arg.cpp
@@ -18,6 +18,7 @@

#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <yaml-cpp/yaml.h>

#include <algorithm>
#include <climits>
@@ -65,6 +66,166 @@ static void write_file(const std::string & fname, const std::string & content) {
file.close();
}

bool common_params_load_from_yaml(const std::string & config_file, common_params & params) {
try {
YAML::Node config = YAML::LoadFile(config_file);
// Model parameters
if (config["model"]) {
if (config["model"]["path"]) {
params.model.path = config["model"]["path"].as<std::string>();
}
if (config["model"]["url"]) {
params.model.url = config["model"]["url"].as<std::string>();
}
if (config["model"]["hf_repo"]) {
params.model.hf_repo = config["model"]["hf_repo"].as<std::string>();
}
if (config["model"]["hf_file"]) {
params.model.hf_file = config["model"]["hf_file"].as<std::string>();
}
}

// Basic parameters
if (config["n_predict"]) params.n_predict = config["n_predict"].as<int32_t>();
if (config["n_ctx"]) params.n_ctx = config["n_ctx"].as<int32_t>();
if (config["n_batch"]) params.n_batch = config["n_batch"].as<int32_t>();
if (config["n_ubatch"]) params.n_ubatch = config["n_ubatch"].as<int32_t>();
if (config["n_keep"]) params.n_keep = config["n_keep"].as<int32_t>();
if (config["n_chunks"]) params.n_chunks = config["n_chunks"].as<int32_t>();
if (config["n_parallel"]) params.n_parallel = config["n_parallel"].as<int32_t>();
if (config["n_sequences"]) params.n_sequences = config["n_sequences"].as<int32_t>();
if (config["n_gpu_layers"]) params.n_gpu_layers = config["n_gpu_layers"].as<int32_t>();
if (config["main_gpu"]) params.main_gpu = config["main_gpu"].as<int32_t>();
if (config["verbosity"]) params.verbosity = config["verbosity"].as<int32_t>();
// String parameters
if (config["prompt"]) params.prompt = config["prompt"].as<std::string>();
if (config["system_prompt"]) params.system_prompt = config["system_prompt"].as<std::string>();
if (config["prompt_file"]) params.prompt_file = config["prompt_file"].as<std::string>();
if (config["input_prefix"]) params.input_prefix = config["input_prefix"].as<std::string>();
if (config["input_suffix"]) params.input_suffix = config["input_suffix"].as<std::string>();
if (config["hf_token"]) params.hf_token = config["hf_token"].as<std::string>();
// Float parameters
if (config["rope_freq_base"]) params.rope_freq_base = config["rope_freq_base"].as<float>();
if (config["rope_freq_scale"]) params.rope_freq_scale = config["rope_freq_scale"].as<float>();
if (config["yarn_ext_factor"]) params.yarn_ext_factor = config["yarn_ext_factor"].as<float>();
if (config["yarn_attn_factor"]) params.yarn_attn_factor = config["yarn_attn_factor"].as<float>();
if (config["yarn_beta_fast"]) params.yarn_beta_fast = config["yarn_beta_fast"].as<float>();
if (config["yarn_beta_slow"]) params.yarn_beta_slow = config["yarn_beta_slow"].as<float>();

// Boolean parameters
if (config["interactive"]) params.interactive = config["interactive"].as<bool>();
if (config["interactive_first"]) params.interactive_first = config["interactive_first"].as<bool>();
if (config["conversation"]) {
params.conversation_mode = config["conversation"].as<bool>() ?
COMMON_CONVERSATION_MODE_ENABLED : COMMON_CONVERSATION_MODE_DISABLED;
}
if (config["use_color"]) params.use_color = config["use_color"].as<bool>();
if (config["simple_io"]) params.simple_io = config["simple_io"].as<bool>();
if (config["embedding"]) params.embedding = config["embedding"].as<bool>();
if (config["escape"]) params.escape = config["escape"].as<bool>();
if (config["multiline_input"]) params.multiline_input = config["multiline_input"].as<bool>();
if (config["cont_batching"]) params.cont_batching = config["cont_batching"].as<bool>();
if (config["flash_attn"]) {
params.flash_attn_type = config["flash_attn"].as<bool>() ?
LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
}
if (config["no_perf"]) params.no_perf = config["no_perf"].as<bool>();
if (config["ctx_shift"]) params.ctx_shift = config["ctx_shift"].as<bool>();
if (config["input_prefix_bos"]) params.input_prefix_bos = config["input_prefix_bos"].as<bool>();
if (config["use_mmap"]) params.use_mmap = config["use_mmap"].as<bool>();
if (config["use_mlock"]) params.use_mlock = config["use_mlock"].as<bool>();
if (config["verbose_prompt"]) params.verbose_prompt = config["verbose_prompt"].as<bool>();
if (config["display_prompt"]) params.display_prompt = config["display_prompt"].as<bool>();
if (config["no_kv_offload"]) params.no_kv_offload = config["no_kv_offload"].as<bool>();
if (config["warmup"]) params.warmup = config["warmup"].as<bool>();
if (config["check_tensors"]) params.check_tensors = config["check_tensors"].as<bool>();

// CPU parameters
if (config["cpuparams"]) {
const auto & cpu_config = config["cpuparams"];
if (cpu_config["n_threads"]) params.cpuparams.n_threads = cpu_config["n_threads"].as<int>();
if (cpu_config["strict_cpu"]) params.cpuparams.strict_cpu = cpu_config["strict_cpu"].as<bool>();
if (cpu_config["poll"]) params.cpuparams.poll = cpu_config["poll"].as<uint32_t>();
}

// Sampling parameters
if (config["sampling"]) {
const auto & sampling_config = config["sampling"];
if (sampling_config["seed"]) params.sampling.seed = sampling_config["seed"].as<uint32_t>();
if (sampling_config["n_prev"]) params.sampling.n_prev = sampling_config["n_prev"].as<int32_t>();
if (sampling_config["n_probs"]) params.sampling.n_probs = sampling_config["n_probs"].as<int32_t>();
if (sampling_config["min_keep"]) params.sampling.min_keep = sampling_config["min_keep"].as<int32_t>();
if (sampling_config["top_k"]) params.sampling.top_k = sampling_config["top_k"].as<int32_t>();
if (sampling_config["top_p"]) params.sampling.top_p = sampling_config["top_p"].as<float>();
if (sampling_config["min_p"]) params.sampling.min_p = sampling_config["min_p"].as<float>();
if (sampling_config["xtc_probability"]) params.sampling.xtc_probability = sampling_config["xtc_probability"].as<float>();
if (sampling_config["xtc_threshold"]) params.sampling.xtc_threshold = sampling_config["xtc_threshold"].as<float>();
if (sampling_config["typ_p"]) params.sampling.typ_p = sampling_config["typ_p"].as<float>();
if (sampling_config["temp"]) params.sampling.temp = sampling_config["temp"].as<float>();
if (sampling_config["dynatemp_range"]) params.sampling.dynatemp_range = sampling_config["dynatemp_range"].as<float>();
if (sampling_config["dynatemp_exponent"]) params.sampling.dynatemp_exponent = sampling_config["dynatemp_exponent"].as<float>();
if (sampling_config["penalty_last_n"]) params.sampling.penalty_last_n = sampling_config["penalty_last_n"].as<int32_t>();
if (sampling_config["penalty_repeat"]) params.sampling.penalty_repeat = sampling_config["penalty_repeat"].as<float>();
if (sampling_config["penalty_freq"]) params.sampling.penalty_freq = sampling_config["penalty_freq"].as<float>();
if (sampling_config["penalty_present"]) params.sampling.penalty_present = sampling_config["penalty_present"].as<float>();
if (sampling_config["dry_multiplier"]) params.sampling.dry_multiplier = sampling_config["dry_multiplier"].as<float>();
if (sampling_config["dry_base"]) params.sampling.dry_base = sampling_config["dry_base"].as<float>();
if (sampling_config["dry_allowed_length"]) params.sampling.dry_allowed_length = sampling_config["dry_allowed_length"].as<int32_t>();
if (sampling_config["dry_penalty_last_n"]) params.sampling.dry_penalty_last_n = sampling_config["dry_penalty_last_n"].as<int32_t>();
if (sampling_config["mirostat"]) params.sampling.mirostat = sampling_config["mirostat"].as<int32_t>();
if (sampling_config["top_n_sigma"]) params.sampling.top_n_sigma = sampling_config["top_n_sigma"].as<float>();
if (sampling_config["mirostat_tau"]) params.sampling.mirostat_tau = sampling_config["mirostat_tau"].as<float>();
if (sampling_config["mirostat_eta"]) params.sampling.mirostat_eta = sampling_config["mirostat_eta"].as<float>();
if (sampling_config["ignore_eos"]) params.sampling.ignore_eos = sampling_config["ignore_eos"].as<bool>();
if (sampling_config["no_perf"]) params.sampling.no_perf = sampling_config["no_perf"].as<bool>();
if (sampling_config["timing_per_token"]) params.sampling.timing_per_token = sampling_config["timing_per_token"].as<bool>();
if (sampling_config["grammar"]) params.sampling.grammar = sampling_config["grammar"].as<std::string>();
if (sampling_config["grammar_lazy"]) params.sampling.grammar_lazy = sampling_config["grammar_lazy"].as<bool>();

if (sampling_config["dry_sequence_breakers"]) {
params.sampling.dry_sequence_breakers.clear();
for (const auto & breaker : sampling_config["dry_sequence_breakers"]) {
params.sampling.dry_sequence_breakers.push_back(breaker.as<std::string>());
}
}
}

// Speculative parameters
if (config["speculative"]) {
const auto & spec_config = config["speculative"];
if (spec_config["n_ctx"]) params.speculative.n_ctx = spec_config["n_ctx"].as<int32_t>();
if (spec_config["n_max"]) params.speculative.n_max = spec_config["n_max"].as<int32_t>();
if (spec_config["n_min"]) params.speculative.n_min = spec_config["n_min"].as<int32_t>();
if (spec_config["n_gpu_layers"]) params.speculative.n_gpu_layers = spec_config["n_gpu_layers"].as<int32_t>();
if (spec_config["p_split"]) params.speculative.p_split = spec_config["p_split"].as<float>();
if (spec_config["p_min"]) params.speculative.p_min = spec_config["p_min"].as<float>();

if (spec_config["model"]) {
const auto & model_config = spec_config["model"];
if (model_config["path"]) params.speculative.model.path = model_config["path"].as<std::string>();
if (model_config["url"]) params.speculative.model.url = model_config["url"].as<std::string>();
if (model_config["hf_repo"]) params.speculative.model.hf_repo = model_config["hf_repo"].as<std::string>();
if (model_config["hf_file"]) params.speculative.model.hf_file = model_config["hf_file"].as<std::string>();
}
}

if (config["antiprompt"]) {
params.antiprompt.clear();
for (const auto & antiprompt : config["antiprompt"]) {
params.antiprompt.push_back(antiprompt.as<std::string>());
}
}

return true;
} catch (const YAML::Exception & e) {
LOG_ERR("Error parsing YAML config file '%s': %s\n", config_file.c_str(), e.what());
return false;
} catch (const std::exception & e) {
LOG_ERR("Error loading YAML config file '%s': %s\n", config_file.c_str(), e.what());
return false;
}
}
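
For orientation, a minimal caller sketch (the wiring is illustrative, not part of this diff): the loader fills in only the keys present in the file, leaves every other field at its built-in default, and reports YAML problems through the false return value after logging them.

#include "arg.h"
#include "common.h"

#include <cstdio>

int main() {
    common_params params;
    // keys absent from the file keep their defaults; a malformed file is
    // logged by the loader and surfaces here as a false return
    if (!common_params_load_from_yaml("examples/config.yaml", params)) {
        return 1;
    }
    printf("n_ctx = %d, temp = %.2f\n", params.n_ctx, (double) params.sampling.temp);
    return 0;
}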

common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
this->examples = std::move(examples);
return *this;
@@ -228,8 +389,7 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma

CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}

int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
@@ -1227,6 +1387,20 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
ctx_arg.params = params_org;
return false;
}

// Load YAML config if specified
if (!ctx_arg.params.config_file.empty()) {
if (!common_params_load_from_yaml(ctx_arg.params.config_file, ctx_arg.params)) {
ctx_arg.params = params_org;
return false;
}

        // parse the command line again so explicit CLI flags override values loaded from the YAML file
        if (!common_params_parse_ex(argc, argv, ctx_arg)) {
ctx_arg.params = params_org;
return false;
}
}

if (ctx_arg.params.usage) {
common_params_print_usage(ctx_arg);
if (ctx_arg.print_usage) {
@@ -1317,6 +1491,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.completion = true;
}
));
add_opt(common_arg(
{"--config"}, "FNAME",
"path to a YAML config file (default: none)",
[](common_params & params, const std::string & value) {
params.config_file = value;
}
));
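
Usage note: because the file is applied inside common_params_parse and the remaining flags are parsed a second time afterwards (see the second common_params_parse_ex call above), an invocation along the lines of "llama-cli --config config.yaml --temp 0.2" takes its defaults from the YAML file while explicit command-line flags still win. The binary and file names here are illustrative.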
add_opt(common_arg(
{"--verbose-prompt"},
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
3 changes: 3 additions & 0 deletions common/arg.h
@@ -76,6 +76,9 @@ struct common_params_context {
// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

// load parameters from YAML config file
bool common_params_load_from_yaml(const std::string & config_file, common_params & params);

// function to be used by test-arg-parser
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
bool common_has_curl();
1 change: 1 addition & 0 deletions common/common.h
@@ -332,6 +332,7 @@ struct common_params {
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string config_file = ""; // path to YAML config file // NOLINT

std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
106 changes: 106 additions & 0 deletions examples/config.yaml
@@ -0,0 +1,106 @@
# Example YAML configuration for llama.cpp; pass it via --config (see common/arg.cpp)
model:
path: "models/my-model.gguf"

n_predict: 128
n_ctx: 4096
n_batch: 512
n_ubatch: 512
n_keep: 0
n_gpu_layers: 32
main_gpu: 0
verbosity: 1

prompt: "Hello, how are you?"
system_prompt: "You are a helpful assistant."
input_prefix: "User: "
input_suffix: "\nAssistant: "

rope_freq_base: 10000.0
rope_freq_scale: 1.0
yarn_ext_factor: 1.0
yarn_attn_factor: 1.0
yarn_beta_fast: 32.0
yarn_beta_slow: 1.0

interactive: false
interactive_first: false
conversation: false
use_color: true
simple_io: false
embedding: false
escape: true
multiline_input: false
cont_batching: true
flash_attn: false
no_perf: false
ctx_shift: true
input_prefix_bos: false
logits_all: false
use_mmap: true
use_mlock: false
verbose_prompt: false
display_prompt: true
dump_kv_cache: false
no_kv_offload: false
warmup: true
check_tensors: false

cache_type_k: "f16"
cache_type_v: "f16"

cpuparams:
n_threads: 8
strict_cpu: false
poll: 50

sampling:
seed: 42
n_prev: 64
n_probs: 0
min_keep: 0
top_k: 40
top_p: 0.95
min_p: 0.05
xtc_probability: 0.0
xtc_threshold: 0.1
typ_p: 1.0
temp: 0.8
dynatemp_range: 0.0
dynatemp_exponent: 1.0
penalty_last_n: 64
penalty_repeat: 1.0
penalty_freq: 0.0
penalty_present: 0.0
dry_multiplier: 0.0
dry_base: 1.75
dry_allowed_length: 2
dry_penalty_last_n: -1
mirostat: 0
top_n_sigma: 0.0
mirostat_tau: 5.0
mirostat_eta: 0.1
ignore_eos: false
no_perf: false
timing_per_token: false
grammar: ""
grammar_lazy: false
dry_sequence_breakers:
- "\n"
- ":"
- "\""

speculative:
n_ctx: 0
n_max: 16
n_min: 5
n_gpu_layers: 0
p_split: 0.1
p_min: 0.9
model:
path: ""

antiprompt:
- "User:"
- "Human:"
- "\n\n"
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
@@ -190,6 +190,7 @@ llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyll
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
if (NOT WIN32)
llama_build_and_test(test-arg-parser.cpp)
llama_build_and_test(test-yaml-config.cpp)
endif()
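
The new test-yaml-config.cpp is registered above but its body falls outside this diff. Purely as an illustrative sketch under assumed contents (this is not the PR's actual test), a minimal check could look like:

#include "arg.h"
#include "common.h"

#include <cassert>
#include <cstdio>
#include <fstream>

int main() {
    // write a tiny config to a temp file (path and keys are illustrative)
    const char * path = "test-config.yaml";
    std::ofstream(path) << "n_ctx: 2048\nsampling:\n  temp: 0.5\n";

    common_params params;
    assert(common_params_load_from_yaml(path, params));
    assert(params.n_ctx == 2048);
    assert(params.sampling.temp == 0.5f);  // 0.5 is exactly representable

    std::remove(path);
    return 0;
}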

if (NOT LLAMA_SANITIZE_ADDRESS)