Commit 5b7f5fe

server : fix segfault on long system prompt (#8987)

* server : fix segfault on long system prompt
* server : fix parallel generation with very small batch sizes
* server : fix typo in comment

Author: Compilade

1 parent a13ec57 commit 5b7f5fe
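
As the diff below shows, the server previously pushed every system-prompt token into batch with llama_batch_add() before slicing it into views, but batch is only allocated to hold n_batch tokens, so any system prompt longer than n_batch wrote past the end of the batch arrays. The fix fills and decodes the batch one n_batch-sized chunk at a time. A minimal sketch of that pattern follows; it is not the server code itself: decode_prompt_chunked is a hypothetical helper name, and it assumes a loaded llama_context, a batch created with llama_batch_init(), and the llama_batch_clear()/llama_batch_add() helpers from common.h.

    #include "common.h"
    #include "llama.h"

    #include <algorithm>
    #include <vector>

    // Decode `tokens` into sequence 0 without ever exceeding the batch
    // capacity; returns false if llama_decode() fails.
    static bool decode_prompt_chunked(llama_context * ctx, llama_batch & batch,
                                      const std::vector<llama_token> & tokens) {
        const int32_t n_batch         = llama_n_batch(ctx);
        const int32_t n_tokens_prompt = (int32_t) tokens.size();

        for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
            // this chunk is a full n_batch, or whatever remains at the end
            const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

            llama_batch_clear(batch);

            for (int32_t j = 0; j < n_tokens; ++j) {
                // absolute position i + j keeps the KV cache consistent across chunks
                llama_batch_add(batch, tokens[i + j], i + j, { 0 }, false);
            }

            if (llama_decode(ctx, batch) != 0) {
                return false;
            }
        }

        return true;
    }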

File tree

1 file changed: +11 -20 lines

examples/server/server.cpp

Lines changed: 11 additions & 20 deletions
@@ -919,13 +919,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;
 
-        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
         }
 
         metrics.init();
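
This first hunk covers the parallel-generation case from the commit message: during generation each of the n_parallel slots contributes one token per decode step, so with a batch size smaller than the slot count (say -b 4 with -np 8) update_slots() could add more tokens than llama_batch_init() had allocated. Sizing the allocation as std::max(n_batch, params.n_parallel) covers both prompt processing and parallel generation.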
@@ -1345,28 +1345,19 @@ struct server_context {
         if (!system_prompt.empty()) {
             system_tokens = ::llama_tokenize(ctx, system_prompt, true);
 
-            llama_batch_clear(batch);
+            const int32_t n_batch         = llama_n_batch(ctx);
+            const int32_t n_tokens_prompt = system_tokens.size();
 
-            for (int i = 0; i < (int)system_tokens.size(); ++i) {
-                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-            }
+            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
+                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
 
-            const int32_t n_batch = llama_n_batch(ctx);
+                llama_batch_clear(batch);
 
-            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
-                llama_batch batch_view = {
-                    n_tokens,
-                    batch.token    + i,
-                    nullptr,
-                    batch.pos      + i,
-                    batch.n_seq_id + i,
-                    batch.seq_id   + i,
-                    batch.logits   + i,
-                    0, 0, 0, // unused
-                };
+                for (int32_t j = 0; j < n_tokens; ++j) {
+                    llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
+                }
 
-                if (llama_decode(ctx, batch_view) != 0) {
+                if (llama_decode(ctx, batch) != 0) {
                     LOG_ERROR("llama_decode() failed", {});
                     return;
                 }
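
In this second hunk the chunking loop drives llama_batch_clear()/llama_batch_add() directly, so at most n_batch tokens are ever staged before each llama_decode() call. The removed batch_view slicing could not prevent the crash: the out-of-bounds writes happened earlier, inside llama_batch_add(), when the whole prompt was appended at once.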
