16 changes: 16 additions & 0 deletions llamafile/flags.cpp
@@ -72,11 +72,13 @@ float FLAG_presence_penalty = 0;
 float FLAG_reserve_tokens = .15;
 float FLAG_temperature = .8;
 float FLAG_top_p = .95;
+float FLAG_min_p = 0.05;
 int FLAG_batch = 256;
 int FLAG_ctx_size = 8192;
 int FLAG_decay_delay = 60 * 5;
 int FLAG_flash_attn = false;
 int FLAG_gpu = 0;
+int FLAG_top_k = 40;
 int FLAG_http_ibuf_size = 5 * 1024 * 1024;
 int FLAG_http_obuf_size = 1024 * 1024;
 int FLAG_keepalive = 5;
@@ -371,6 +373,20 @@ void llamafile_get_flags(int argc, char **argv) {
             continue;
         }
 
+        if (!strcmp(flag, "--min-p")) {
+            if (i == argc)
+                missing("--min-p");
+            FLAG_min_p = atof(argv[i++]);
+            continue;
+        }
+
+        if (!strcmp(flag, "--top-k")) {
+            if (i == argc)
+                missing("--top-k");
+            FLAG_top_k = strtol(argv[i++], 0, 0);
+            continue;
+        }
+
         if (!strcmp(flag, "--frequency-penalty")) {
             if (i == argc)
                 missing("--frequency-penalty");
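With both flags wired up, the samplers can be tuned from the command line. A hypothetical invocation (the model path and prompt are illustrative; the flags are the ones added above):

    llamafile -m model.gguf -p "Once upon a time" --min-p 0.05 --top-k 40

Note that --top-k goes through strtol(argv[i++], 0, 0), whose base-0 mode also accepts hex or octal input, while --min-p goes through atof() and takes any decimal.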
2 changes: 2 additions & 0 deletions llamafile/llamafile.h
@@ -42,11 +42,13 @@ extern float FLAG_presence_penalty;
 extern float FLAG_reserve_tokens;
 extern float FLAG_temperature;
 extern float FLAG_top_p;
+extern float FLAG_min_p;
 extern int FLAG_batch;
 extern int FLAG_ctx_size;
 extern int FLAG_decay_delay;
 extern int FLAG_flash_attn;
 extern int FLAG_gpu;
+extern int FLAG_top_k;
 extern int FLAG_http_ibuf_size;
 extern int FLAG_http_obuf_size;
31 changes: 28 additions & 3 deletions llamafile/server/v1_completions.cpp
@@ -18,6 +18,7 @@
#include "client.h"
#include "llama.cpp/llama.h"
#include "llama.cpp/sampling.h"
#include "llamafile/llamafile.h"
#include "llamafile/json.h"
#include "llamafile/llama.h"
#include "llamafile/macros.h"
@@ -48,9 +49,11 @@ struct V1CompletionParams
     bool stream = false;
     bool stream_include_usage = false;
     long max_tokens = -1;
-    long seed = _rand64();
-    double top_p = 1;
-    double temperature = 1;
+    long seed = FLAG_seed;
+    double top_p = FLAG_top_p;
+    double min_p = FLAG_min_p;
+    long top_k = FLAG_top_k;
+    double temperature = FLAG_temperature;
     double presence_penalty = 0;
     double frequency_penalty = 0;
     std::string user;
@@ -147,6 +150,8 @@ create_sampler(const V1CompletionParams* params)
     llama_sampling_params sparams;
     sparams.temp = params->temperature;
     sparams.top_p = params->top_p;
+    sparams.min_p = params->min_p;
+    sparams.top_k = params->top_k;
     sparams.penalty_freq = params->frequency_penalty;
     sparams.penalty_present = params->presence_penalty;
     sparams.seed = params->seed;
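For context on what the two new fields control: top_k truncates the candidate set to the k most probable tokens, and min_p then drops any token whose probability falls below min_p times the probability of the most likely token. The following is a minimal standalone sketch of that filtering order, not llama.cpp's actual sampler code:

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Sketch only: apply top-k truncation, then min-p filtering, to a
    // probability distribution sorted in descending order.
    static std::vector<float> filter(std::vector<float> probs, int top_k, float min_p) {
        std::sort(probs.begin(), probs.end(), std::greater<float>());
        if (top_k > 0 && (size_t)top_k < probs.size())
            probs.resize(top_k);              // keep the k most probable tokens
        float cutoff = min_p * probs.front(); // threshold scales with the top token
        while (!probs.empty() && probs.back() < cutoff)
            probs.pop_back();
        return probs;
    }

    int main() {
        // With min_p = 0.05 the cutoff is 0.05 * 0.50 = 0.025, so the two
        // least likely tokens are dropped; top_k = 40 is not binding here.
        for (float p : filter({.50f, .25f, .15f, .07f, .02f, .01f}, 40, 0.05f))
            printf("%.2f\n", p);              // prints 0.50 0.25 0.15 0.07
    }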
@@ -305,6 +310,26 @@ Client::get_v1_completions_params(V1CompletionParams* params)
return send_error(400, "top_p must be between 0 and 1");
}

// min_p: number|null
Json& min_p = json["min_p"];
if (!min_p.isNull()) {
if (!min_p.isNumber())
return send_error(400, "min_p must be number");
params->min_p = min_p.getNumber();
if (!(0 <= params->min_p && params->min_p <= 1))
return send_error(400, "min_p must be between 0 and 1");
}

// top_k: integer|null
Json& top_k = json["top_k"];
if (!top_k.isNull()) {
if (!top_k.isNumber())
return send_error(400, "top_k must be number");
params->top_k = top_k.getNumber();
if (params->top_k < 0)
return send_error(400, "top_k must be non-negative");
}

// temperature: number|null
//
// What sampling temperature to use, between 0 and 2. Higher values
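Once deployed, the parameters can also be overridden per request through the OpenAI-compatible endpoint. A sketch of such a request, assuming a server on llamafile's default port 8080 (the model and prompt values are illustrative):

    curl http://localhost:8080/v1/completions \
      -H "Content-Type: application/json" \
      -d '{"model": "test", "prompt": "Once upon a time", "min_p": 0.05, "top_k": 40}'

Any field omitted from the JSON body now falls back to the corresponding FLAG_* default, since the V1CompletionParams initializers read FLAG_seed, FLAG_top_p, FLAG_min_p, FLAG_top_k, and FLAG_temperature.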