
Commit 012f880

server : support time limit for generation phase
1 parent 59777c6 commit 012f880

3 files changed: +60 −32 lines

examples/llama.vim

Lines changed: 32 additions & 29 deletions
@@ -12,14 +12,14 @@ highlight llama_hl_hint guifg=#ff772f
 highlight llama_hl_info guifg=#77ff2f

 let s:default_config = {
-    \ 'endpoint': 'http://127.0.0.1:8012/infill',
-    \ 'n_prefix': 128,
-    \ 'n_suffix': 128,
-    \ 'n_predict': 64,
-    \ 'n_probs': 3,
-    \ 'temperature': 0.1,
-    \ 'auto_fim': v:true,
-    \ 'stop': ["\n"]
+    \ 'endpoint': 'http://127.0.0.1:8012/infill',
+    \ 'n_prefix': 128,
+    \ 'n_suffix': 128,
+    \ 'n_predict': 64,
+    \ 't_max_prompt_ms': 300,
+    \ 't_max_predict_ms': 200,
+    \ 'auto_fim': v:true,
+    \ 'stop': ["\n"]
     \ }

 let g:llama_config = get(g:, 'llama_config', s:default_config)
@@ -48,6 +48,8 @@ function! llama#init()
         autocmd!
         autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <C-O>:call llama#fim(v:false)<CR>
         autocmd InsertLeave * call llama#fim_cancel()
+
+        autocmd CursorMoved * call llama#fim_cancel()
     augroup END

     silent! call llama#fim_cancel()
@@ -85,19 +87,20 @@ function! llama#fim(is_auto) abort
        \ . "\n"

    let l:request = json_encode({
-        \ 'prompt': "",
-        \ 'input_prefix': l:prefix,
-        \ 'input_suffix': l:suffix,
-        "\ 'stop': g:llama_config.stop,
-        \ 'n_predict': g:llama_config.n_predict,
-        "\ 'n_probs': g:llama_config.n_probs,
-        \ 'penalty_last_n': 0,
-        \ 'temperature': g:llama_config.temperature,
-        \ 'top_k': 5,
-        \ 'infill_p': 0.20,
-        \ 'infill_p_eog': 0.001,
-        \ 'stream': v:false,
-        \ 'samplers': ["top_k", "infill"]
+        \ 'prompt': "",
+        \ 'input_prefix': l:prefix,
+        \ 'input_suffix': l:suffix,
+        "\ 'stop': g:llama_config.stop,
+        \ 'n_predict': g:llama_config.n_predict,
+        \ 'penalty_last_n': 0,
+        \ 'top_k': 5,
+        \ 'infill_p': 0.20,
+        \ 'infill_p_eog': 0.001,
+        \ 'stream': v:false,
+        \ 'samplers': ["top_k", "infill"],
+        \ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms,
+        \ 't_max_predict_ms': g:llama_config.t_max_predict_ms,
+        \ 'cache_prompt': v:true
        \ })

    let l:curl_command = printf(
@@ -181,9 +184,9 @@ function! s:fim_on_stdout(job_id, data, event) dict
    let l:t_prompt_ms = 1.0
    let l:s_prompt = 0

-    let l:n_gen = 0
-    let l:t_gen_ms = 1.0
-    let l:s_gen = 0
+    let l:n_predict = 0
+    let l:t_predict_ms = 1.0
+    let l:s_predict = 0

    if s:can_accept && v:shell_error
        if !self.is_auto
@@ -221,9 +224,9 @@ function! s:fim_on_stdout(job_id, data, event) dict
        let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1)
        let l:s_prompt = get(l:timings, 'prompt_per_second', 0)

-        let l:n_gen = get(l:timings, 'predicted_n', 0)
-        let l:t_gen_ms = get(l:timings, 'predicted_ms', 1)
-        let l:s_gen = get(l:timings, 'predicted_per_second', 0)
+        let l:n_predict = get(l:timings, 'predicted_n', 0)
+        let l:t_predict_ms = get(l:timings, 'predicted_ms', 1)
+        let l:s_predict = get(l:timings, 'predicted_per_second', 0)
        endif
    endif

@@ -256,8 +259,8 @@ function! s:fim_on_stdout(job_id, data, event) dict

    let l:info = printf("%s | prompt: %d (%.2f ms, %.2f t/s) | predict: %d (%.2f ms, %.2f t/s) | total: %f.2 ms",
        \ l:prefix,
-        \ l:n_prompt, l:t_prompt_ms, l:s_prompt,
-        \ l:n_gen, l:t_gen_ms, l:s_gen,
+        \ l:n_prompt, l:t_prompt_ms, l:s_prompt,
+        \ l:n_predict, l:t_predict_ms, l:s_predict,
        \ 1000.0 * reltimefloat(reltime(s:t_fim_start))
        \ )


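For reference, a minimal sketch (not part of the commit) of the /infill request body that llama.vim now builds, including the two new time-limit fields. It uses C++ with nlohmann::json, which the server already depends on; the prefix/suffix strings are made-up placeholders and the limits are the plugin defaults shown above.

```cpp
// Sketch of the JSON body llama.vim POSTs to /infill after this change.
// Assumes nlohmann::json; prefix/suffix strings are placeholders, the time
// limits are the new plugin defaults (t_max_prompt_ms = 300, t_max_predict_ms = 200).
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json request = {
        {"prompt",           ""},
        {"input_prefix",     "int main() {\n    int x = "},  // text before the cursor (placeholder)
        {"input_suffix",     ";\n    return 0;\n}\n"},       // text after the cursor (placeholder)
        {"n_predict",        64},
        {"penalty_last_n",   0},
        {"top_k",            5},
        {"infill_p",         0.20},
        {"infill_p_eog",     0.001},
        {"stream",           false},
        {"samplers",         nlohmann::json::array({"top_k", "infill"})},
        {"t_max_prompt_ms",  300},  // new: time budget for prompt processing
        {"t_max_predict_ms", 200},  // new: time budget for the generation phase
        {"cache_prompt",     true},
    };

    std::cout << request.dump(2) << std::endl;  // this string is what the curl command sends
    return 0;
}
```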
examples/server/server.cpp

Lines changed: 24 additions & 3 deletions
@@ -128,9 +128,12 @@ struct slot_params {
    bool stream = true;
    bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt

-    int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
-    int32_t n_predict = -1; // new tokens to predict
+    int32_t n_keep = 0; // number of tokens to keep from initial prompt
+    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_predict = -1; // new tokens to predict
+
+    int64_t t_max_prompt_ms = -1;
+    int64_t t_max_predict_ms = -1;

    std::vector<std::string> antiprompt;


@@ -968,6 +971,10 @@ struct server_context {
            }
        }

+        // time limits
+        slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms);
+        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+
        {
            slot.sparams.logit_bias.clear();


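The two json_value calls above fall back to the slot defaults (-1, i.e. no limit) when a request omits a field. Below is a self-contained sketch of that defaulting behavior; the json_value template here is a simplified stand-in for the server's helper, not the actual implementation.

```cpp
// Simplified stand-in for the server's json_value helper: return body[key] if
// present, otherwise the supplied default. Both new limits default to -1 (off).
#include <nlohmann/json.hpp>
#include <cstdint>
#include <cstdio>
#include <string>

template <typename T>
static T json_value(const nlohmann::json & body, const std::string & key, const T & default_value) {
    return body.contains(key) ? body.at(key).get<T>() : default_value;
}

int main() {
    // hypothetical request that only sets the generation-phase limit
    const nlohmann::json data = { {"t_max_predict_ms", 200} };

    const int64_t t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  int64_t(-1));
    const int64_t t_max_predict_ms = json_value(data, "t_max_predict_ms", int64_t(-1));

    std::printf("t_max_prompt_ms = %lld, t_max_predict_ms = %lld\n",
                (long long) t_max_prompt_ms, (long long) t_max_predict_ms);
    return 0;
}
```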
@@ -1183,6 +1190,13 @@
            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
        }

+        if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+            slot.stopped_limit = true;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+        }
+
        // if context shift is disabled, we stop when it reaches the context limit
        if (slot.n_decoded >= slot.n_ctx) {
            slot.truncated = true;
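The new stop condition compares the time elapsed since generation started, measured in microseconds by ggml_time_us(), against the millisecond budget, hence the 1000.0f factor. A standalone sketch of the same rule follows, using std::chrono so it builds without ggml; the names are illustrative, not the server's actual types.

```cpp
// Standalone sketch of the stop rule above: generation ends once the wall time
// since generation started exceeds t_max_predict_ms. The server measures with
// ggml_time_us(); std::chrono is used here so the sketch compiles on its own.
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <thread>

int main() {
    const int64_t t_max_predict_ms = 200;  // hypothetical per-request limit
    const auto t_start_generation = std::chrono::steady_clock::now();

    int  n_decoded      = 0;
    bool has_next_token = true;

    while (has_next_token) {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));  // stand-in for decoding one token
        n_decoded++;

        const int64_t elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - t_start_generation).count();

        if (t_max_predict_ms > 0 && elapsed_ms > t_max_predict_ms) {
            has_next_token = false;  // corresponds to slot.stopped_limit / has_next_token above
            std::printf("stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n",
                        n_decoded, (int) t_max_predict_ms);
        }
    }
    return 0;
}
```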
@@ -2004,6 +2018,13 @@ struct server_context {
            auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
            auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

+            // for now pick context to fit in a single batch
+            const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/2);
+            const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+
+            prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
+            suffix_tokens.resize(n_suffix_take);
+
            prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
            suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));


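The truncation above caps the suffix at half the batch and gives the prefix whatever remains, minus 3 slots presumably reserved for the FIM special tokens inserted right after; the tail of the prefix and the head of the suffix survive, i.e. the tokens closest to the cursor. A standalone sketch of the arithmetic with made-up sizes (plain int stands in for llama_token):

```cpp
// Sketch of the batch-budget split: suffix gets at most n_batch/2, the prefix
// gets the rest minus 3 slots for the FIM special tokens. The *end* of the
// prefix and the *start* of the suffix are kept.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_batch = 512;                  // hypothetical batch size
    std::vector<int> prefix_tokens(800, 1);   // pretend prefix of 800 tokens
    std::vector<int> suffix_tokens(400, 2);   // pretend suffix of 400 tokens

    const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/2);
    const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);

    // drop the oldest prefix tokens, keeping the most recent n_prefix_take
    prefix_tokens.erase(prefix_tokens.begin(),
                        prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
    // keep only the first n_suffix_take suffix tokens
    suffix_tokens.resize(n_suffix_take);

    std::printf("prefix kept: %zu, suffix kept: %zu, total incl. 3 specials: %zu\n",
                prefix_tokens.size(), suffix_tokens.size(),
                prefix_tokens.size() + suffix_tokens.size() + 3);  // 253 + 256 + 3 = 512
    return 0;
}
```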
src/llama.cpp

Lines changed: 4 additions & 0 deletions
@@ -6767,6 +6767,10 @@ static void llm_load_vocab(
                vocab.special_eog_ids.insert(vocab.special_eom_id);
                LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
            }
+
+            if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
+                vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
+            }
        }

        // build special tokens cache

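The vocab change adds the FIM separator token to the end-of-generation set, so sampling FIM_SEP during infill is treated like hitting EOS/EOT. An illustrative sketch of that membership check with hypothetical token ids (not real vocab values):

```cpp
// Illustrative sketch: once the FIM separator id is part of the end-of-generation
// set, an infill request stops cleanly when that token is sampled.
#include <cstdio>
#include <set>

int main() {
    const int token_null = -1;  // stands in for LLAMA_TOKEN_NULL

    std::set<int> special_eog_ids = { 2 /* EOS */, 32007 /* EOT */ };  // hypothetical ids
    const int special_fim_sep_id  = 32015;                             // hypothetical FIM_SEP id

    if (special_fim_sep_id != token_null && special_eog_ids.count(special_fim_sep_id) == 0) {
        special_eog_ids.insert(special_fim_sep_id);
    }

    const int sampled = 32015;  // pretend the model emitted FIM_SEP
    std::printf("token %d ends generation: %s\n", sampled,
                special_eog_ids.count(sampled) ? "yes" : "no");
    return 0;
}
```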