Skip to content

Commit 97aa444

Browse files
committed
Add missing llama.cpp sampling arguments --min-p and --top-k
Resolves issue #715 by adding support for sampling arguments that were available in llama.cpp but missing in llamafile.

Changes:
- Add FLAG_min_p and FLAG_top_k declarations and parsing
- Add JSON API parameter support with validation
- Connect command line flags to sampling defaults
- Support both CLI and API usage

Note: --samplers requires complex parsing and is left for future work.
1 parent cfa861a commit 97aa444

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

llamafile/flags.cpp

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,11 +72,13 @@ float FLAG_presence_penalty = 0;
7272
float FLAG_reserve_tokens = .15;
7373
float FLAG_temperature = .8;
7474
float FLAG_top_p = .95;
75+
float FLAG_min_p = 0.05;
7576
int FLAG_batch = 256;
7677
int FLAG_ctx_size = 8192;
7778
int FLAG_decay_delay = 60 * 5;
7879
int FLAG_flash_attn = false;
7980
int FLAG_gpu = 0;
81+
int FLAG_top_k = 40;
8082
int FLAG_http_ibuf_size = 5 * 1024 * 1024;
8183
int FLAG_http_obuf_size = 1024 * 1024;
8284
int FLAG_keepalive = 5;
@@ -371,6 +373,20 @@ void llamafile_get_flags(int argc, char **argv) {
371373
continue;
372374
}
373375

376+
if (!strcmp(flag, "--min-p")) {
377+
if (i == argc)
378+
missing("--min-p");
379+
FLAG_min_p = atof(argv[i++]);
380+
continue;
381+
}
382+
383+
if (!strcmp(flag, "--top-k")) {
384+
if (i == argc)
385+
missing("--top-k");
386+
FLAG_top_k = strtol(argv[i++], 0, 0);
387+
continue;
388+
}
389+
374390
if (!strcmp(flag, "--frequency-penalty")) {
375391
if (i == argc)
376392
missing("--frequency-penalty");

llamafile/llamafile.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,11 +42,13 @@ extern float FLAG_presence_penalty;
4242
extern float FLAG_reserve_tokens;
4343
extern float FLAG_temperature;
4444
extern float FLAG_top_p;
45+
extern float FLAG_min_p;
4546
extern int FLAG_batch;
4647
extern int FLAG_ctx_size;
4748
extern int FLAG_decay_delay;
4849
extern int FLAG_flash_attn;
4950
extern int FLAG_gpu;
51+
extern int FLAG_top_k;
5052
extern int FLAG_gpu;
5153
extern int FLAG_http_ibuf_size;
5254
extern int FLAG_http_obuf_size;

llamafile/server/v1_completions.cpp

Lines changed: 28 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
#include "client.h"
1919
#include "llama.cpp/llama.h"
2020
#include "llama.cpp/sampling.h"
21+
#include "llamafile/llamafile.h"
2122
#include "llamafile/json.h"
2223
#include "llamafile/llama.h"
2324
#include "llamafile/macros.h"
@@ -48,9 +49,11 @@ struct V1CompletionParams
4849
bool stream = false;
4950
bool stream_include_usage = false;
5051
long max_tokens = -1;
51-
long seed = _rand64();
52-
double top_p = 1;
53-
double temperature = 1;
52+
long seed = FLAG_seed;
53+
double top_p = FLAG_top_p;
54+
double min_p = FLAG_min_p;
55+
long top_k = FLAG_top_k;
56+
double temperature = FLAG_temperature;
5457
double presence_penalty = 0;
5558
double frequency_penalty = 0;
5659
std::string user;
@@ -147,6 +150,8 @@ create_sampler(const V1CompletionParams* params)
147150
llama_sampling_params sparams;
148151
sparams.temp = params->temperature;
149152
sparams.top_p = params->top_p;
153+
sparams.min_p = params->min_p;
154+
sparams.top_k = params->top_k;
150155
sparams.penalty_freq = params->frequency_penalty;
151156
sparams.penalty_present = params->presence_penalty;
152157
sparams.seed = params->seed;
@@ -305,6 +310,26 @@ Client::get_v1_completions_params(V1CompletionParams* params)
305310
return send_error(400, "top_p must be between 0 and 1");
306311
}
307312

313+
// min_p: number|null
314+
Json& min_p = json["min_p"];
315+
if (!min_p.isNull()) {
316+
if (!min_p.isNumber())
317+
return send_error(400, "min_p must be number");
318+
params->min_p = min_p.getNumber();
319+
if (!(0 <= params->min_p && params->min_p <= 1))
320+
return send_error(400, "min_p must be between 0 and 1");
321+
}
322+
323+
// top_k: integer|null
324+
Json& top_k = json["top_k"];
325+
if (!top_k.isNull()) {
326+
if (!top_k.isNumber())
327+
return send_error(400, "top_k must be number");
328+
params->top_k = top_k.getNumber();
329+
if (params->top_k < 0)
330+
return send_error(400, "top_k must be non-negative");
331+
}
332+
308333
// temperature: number|null
309334
//
310335
// What sampling temperature to use, between 0 and 2. Higher values

0 commit comments

Comments
 (0)