|
18 | 18 | #include "client.h" |
19 | 19 | #include "llama.cpp/llama.h" |
20 | 20 | #include "llama.cpp/sampling.h" |
| 21 | +#include "llamafile/llamafile.h" |
21 | 22 | #include "llamafile/json.h" |
22 | 23 | #include "llamafile/llama.h" |
23 | 24 | #include "llamafile/macros.h" |
@@ -48,9 +49,11 @@ struct V1CompletionParams |
48 | 49 | bool stream = false; |
49 | 50 | bool stream_include_usage = false; |
50 | 51 | long max_tokens = -1; |
51 | | - long seed = _rand64(); |
52 | | - double top_p = 1; |
53 | | - double temperature = 1; |
| 52 | + long seed = FLAG_seed; |
| 53 | + double top_p = FLAG_top_p; |
| 54 | + double min_p = FLAG_min_p; |
| 55 | + long top_k = FLAG_top_k; |
| 56 | + double temperature = FLAG_temperature; |
54 | 57 | double presence_penalty = 0; |
55 | 58 | double frequency_penalty = 0; |
56 | 59 | std::string user; |
@@ -147,6 +150,8 @@ create_sampler(const V1CompletionParams* params) |
147 | 150 | llama_sampling_params sparams; |
148 | 151 | sparams.temp = params->temperature; |
149 | 152 | sparams.top_p = params->top_p; |
| 153 | + sparams.min_p = params->min_p; |
| 154 | + sparams.top_k = params->top_k; |
150 | 155 | sparams.penalty_freq = params->frequency_penalty; |
151 | 156 | sparams.penalty_present = params->presence_penalty; |
152 | 157 | sparams.seed = params->seed; |
@@ -305,6 +310,26 @@ Client::get_v1_completions_params(V1CompletionParams* params) |
305 | 310 | return send_error(400, "top_p must be between 0 and 1"); |
306 | 311 | } |
307 | 312 |
|
| 313 | + // min_p: number|null |
| 314 | + Json& min_p = json["min_p"]; |
| 315 | + if (!min_p.isNull()) { |
| 316 | + if (!min_p.isNumber()) |
| 317 | + return send_error(400, "min_p must be number"); |
| 318 | + params->min_p = min_p.getNumber(); |
| 319 | + if (!(0 <= params->min_p && params->min_p <= 1)) |
| 320 | + return send_error(400, "min_p must be between 0 and 1"); |
| 321 | + } |
| 322 | + |
| 323 | + // top_k: integer|null |
| 324 | + Json& top_k = json["top_k"]; |
| 325 | + if (!top_k.isNull()) { |
| 326 | + if (!top_k.isNumber()) |
| 327 | + return send_error(400, "top_k must be number"); |
| 328 | + params->top_k = top_k.getNumber(); |
| 329 | + if (params->top_k < 0) |
| 330 | + return send_error(400, "top_k must be non-negative"); |
| 331 | + } |
| 332 | + |
308 | 333 | // temperature: number|null |
309 | 334 | // |
310 | 335 | // What sampling temperature to use, between 0 and 2. Higher values |
|
0 commit comments