From 97aa444ac672c59fcfc4469efcf830848f6346ab Mon Sep 17 00:00:00 2001 From: Anivar A Aravind Date: Mon, 21 Jul 2025 04:31:55 +0530 Subject: [PATCH] Add missing llama.cpp sampling arguments --min-p and --top-k Resolves issue #715 by adding support for sampling arguments that were available in llama.cpp but missing in llamafile. Changes: - Add FLAG_min_p and FLAG_top_k declarations and parsing - Add JSON API parameter support with validation - Connect command line flags to sampling defaults - Support both CLI and API usage Note: --samplers requires complex parsing, noted for future work. --- llamafile/flags.cpp | 16 +++++++++++++++ llamafile/llamafile.h | 2 ++ llamafile/server/v1_completions.cpp | 31 ++++++++++++++++++++++++++--- 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/llamafile/flags.cpp b/llamafile/flags.cpp index c0e3bb3b74..87b883b25c 100644 --- a/llamafile/flags.cpp +++ b/llamafile/flags.cpp @@ -72,11 +72,13 @@ float FLAG_presence_penalty = 0; float FLAG_reserve_tokens = .15; float FLAG_temperature = .8; float FLAG_top_p = .95; +float FLAG_min_p = 0.05; int FLAG_batch = 256; int FLAG_ctx_size = 8192; int FLAG_decay_delay = 60 * 5; int FLAG_flash_attn = false; int FLAG_gpu = 0; +int FLAG_top_k = 40; int FLAG_http_ibuf_size = 5 * 1024 * 1024; int FLAG_http_obuf_size = 1024 * 1024; int FLAG_keepalive = 5; @@ -371,6 +373,20 @@ void llamafile_get_flags(int argc, char **argv) { continue; } + if (!strcmp(flag, "--min-p")) { + if (i == argc) + missing("--min-p"); + FLAG_min_p = atof(argv[i++]); + continue; + } + + if (!strcmp(flag, "--top-k")) { + if (i == argc) + missing("--top-k"); + FLAG_top_k = strtol(argv[i++], 0, 0); + continue; + } + if (!strcmp(flag, "--frequency-penalty")) { if (i == argc) missing("--frequency-penalty"); diff --git a/llamafile/llamafile.h b/llamafile/llamafile.h index b74dda60dd..5e23129d2d 100644 --- a/llamafile/llamafile.h +++ b/llamafile/llamafile.h @@ -42,11 +42,13 @@ extern float FLAG_presence_penalty; extern float FLAG_reserve_tokens; extern float FLAG_temperature; extern float FLAG_top_p; +extern float FLAG_min_p; extern int FLAG_batch; extern int FLAG_ctx_size; extern int FLAG_decay_delay; extern int FLAG_flash_attn; extern int FLAG_gpu; +extern int FLAG_top_k; extern int FLAG_gpu; extern int FLAG_http_ibuf_size; extern int FLAG_http_obuf_size; diff --git a/llamafile/server/v1_completions.cpp b/llamafile/server/v1_completions.cpp index f5294d1307..9d18529bc6 100644 --- a/llamafile/server/v1_completions.cpp +++ b/llamafile/server/v1_completions.cpp @@ -18,6 +18,7 @@ #include "client.h" #include "llama.cpp/llama.h" #include "llama.cpp/sampling.h" +#include "llamafile/llamafile.h" #include "llamafile/json.h" #include "llamafile/llama.h" #include "llamafile/macros.h" @@ -48,9 +49,11 @@ struct V1CompletionParams bool stream = false; bool stream_include_usage = false; long max_tokens = -1; - long seed = _rand64(); - double top_p = 1; - double temperature = 1; + long seed = FLAG_seed; + double top_p = FLAG_top_p; + double min_p = FLAG_min_p; + long top_k = FLAG_top_k; + double temperature = FLAG_temperature; double presence_penalty = 0; double frequency_penalty = 0; std::string user; @@ -147,6 +150,8 @@ create_sampler(const V1CompletionParams* params) llama_sampling_params sparams; sparams.temp = params->temperature; sparams.top_p = params->top_p; + sparams.min_p = params->min_p; + sparams.top_k = params->top_k; sparams.penalty_freq = params->frequency_penalty; sparams.penalty_present = params->presence_penalty; sparams.seed = params->seed; @@ -305,6 +310,26 @@ Client::get_v1_completions_params(V1CompletionParams* params) return send_error(400, "top_p must be between 0 and 1"); } + // min_p: number|null + Json& min_p = json["min_p"]; + if (!min_p.isNull()) { + if (!min_p.isNumber()) + return send_error(400, "min_p must be number"); + params->min_p = min_p.getNumber(); + if (!(0 <= params->min_p && params->min_p <= 1)) + return send_error(400, "min_p must be between 0 and 1"); + } + + // top_k: integer|null + Json& top_k = json["top_k"]; + if (!top_k.isNull()) { + if (!top_k.isNumber()) + return send_error(400, "top_k must be number"); + params->top_k = top_k.getNumber(); + if (params->top_k < 0) + return send_error(400, "top_k must be non-negative"); + } + // temperature: number|null // // What sampling temperature to use, between 0 and 2. Higher values