
Commit 85ef80c

server : use llama_batch_ext
1 parent 17d3658 commit 85ef80c


examples/server/server.cpp

Lines changed: 38 additions & 35 deletions
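The diff below replaces the server's direct use of llama_batch (held in a llama_batch_ptr and filled through the common_batch_* helpers) with the opaque llama_batch_ext type, which is only touched through accessor functions. A minimal sketch of the resulting batch lifecycle, assuming the signatures implied by the calls in this commit and that llama_batch_ext_ptr is a unique_ptr-style wrapper that frees the batch on destruction:

// Sketch only: signatures are inferred from the calls in this commit and
// llama_batch_ext_ptr is assumed to be a RAII wrapper around llama_batch_ext.
#include "llama.h"

#include <algorithm>
#include <array>
#include <vector>

static int decode_sequence(llama_context * ctx, const std::vector<llama_token> & tokens, llama_seq_id seq) {
    const int32_t n_batch = llama_n_batch(ctx);

    // capacity for all tokens, one seq_id per token
    llama_batch_ext_ptr batch(llama_batch_ext_init((int32_t) tokens.size(), 1));
    llama_batch_ext_clear(batch.get());

    std::array<llama_token, 1> seq_id = { seq };
    for (size_t i = 0; i < tokens.size(); ++i) {
        // last argument: whether to output logits/embeddings for this token
        llama_batch_ext_add_text_token(batch.get(), tokens[i], (llama_pos) i, seq_id.data(), seq_id.size(), false);
    }

    // request logits only for the final token, as the server does for prompts
    llama_batch_ext_set_logits_last(batch.get());

    // decode in n_batch-sized views, mirroring the loop in update_slots()
    for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);

        llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
        if (llama_text_decode(ctx, batch_view.get()) != 0) {
            return -1;
        }
    }

    return 0;
}

The same clear / add / set-logits-last / view / decode sequence recurs throughout update_slots() in the hunks below.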
@@ -1215,7 +1215,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;

-    llama_batch_ptr batch_spec;
+    llama_batch_ext_ptr batch_spec;

     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {

     llama_context_params cparams_dft;

-    llama_batch_ptr batch;
+    llama_batch_ext_ptr batch;

     bool clean_kv_cache = true;
     bool add_bos_token  = true;
@@ -1940,7 +1940,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;

             if (model_dft) {
-                slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));
+                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));

                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1976,7 +1976,7 @@ struct server_context {
            const int32_t n_batch = llama_n_batch(ctx);

            // only a single seq_id per token is needed
-            batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
+            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
        }

        metrics.init();
@@ -2094,7 +2094,7 @@ struct server_context {
        }

        if (slot.ctx_dft) {
-            slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
+            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
        }

        slot.state = SLOT_STATE_STARTED;
@@ -2402,7 +2402,7 @@ struct server_context {
        queue_results.send(std::move(res));
    }

-    void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
+    void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) {
        auto res = std::make_unique<server_task_result_embd>();
        res->id = slot.id_task;
        res->index = slot.index;
@@ -2413,8 +2413,8 @@ struct server_context {

        std::vector<float> embd_res(n_embd, 0.0f);

-        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
-            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
+            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
            if (!tok.logits || tok.seq_id[0] != slot.id) {
                continue;
            }
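send_embedding() and send_rerank() now read per-token metadata back through llama_batch_ext_get_token_info() instead of indexing the batch's parallel arrays. A small helper in the same spirit, assuming only the fields used in this diff (logits and seq_id); the helper itself is illustrative:

// Assumes llama_batch_ext_get_token_info() and the .logits / .seq_id fields
// shown in this diff; everything else is illustrative.
#include "llama.h"

static bool batch_has_output_for_seq(llama_batch_ext * batch, llama_seq_id seq) {
    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch); ++i) {
        const llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch, i);
        if (tok.logits && tok.seq_id[0] == seq) {
            return true; // this batch position produced output for the sequence
        }
    }
    return false;
}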
@@ -2446,14 +2446,14 @@ struct server_context {
        queue_results.send(std::move(res));
    }

-    void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
+    void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id = slot.id_task;
        res->index = slot.index;
        res->n_tokens = slot.n_prompt_tokens;

-        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
-            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
+            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
            if (!tok.logits || tok.seq_id[0] != slot.id) {
                continue;
            }
@@ -2855,7 +2855,7 @@ struct server_context {
        }

        // start populating the batch for this iteration
-        common_batch_clear(batch.get());
+        llama_batch_ext_clear(batch.get());

        // track if given slot can be batched with slots already in the batch
        server_slot * slot_batched = nullptr;
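In the hunks that follow, common_batch_add() is replaced by llama_batch_ext_add_text_token(), which takes an explicit sequence-id pointer and count instead of an initializer list. A sketch of queuing one sampled token for the next decode, assuming the call shape shown in this commit; the helper name and parameters are illustrative:

// Sketch: add a single generated token at position n_past for sequence seq,
// returning its index in the batch so the logits can be read back after decode.
#include "llama.h"

#include <array>

static int32_t queue_sampled_token(llama_batch_ext * batch, llama_token sampled, llama_pos n_past, llama_seq_id seq) {
    const int32_t i_batch = llama_batch_ext_get_n_tokens(batch);

    std::array<llama_token, 1> seq_id = { seq };
    llama_batch_ext_add_text_token(batch, sampled, n_past, seq_id.data(), seq_id.size(), /*output logits*/ true);

    return i_batch;
}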
@@ -2877,9 +2877,10 @@ struct server_context {
                continue;
            }

-            slot.i_batch = llama_batch_get_n_tokens(batch.get());
+            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());

-            common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);
+            std::array<llama_token, 1> seq_id = { slot.id };
+            llama_batch_ext_add_text_token(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);

            slot.n_past += 1;

@@ -2896,7 +2897,7 @@ struct server_context {
        int32_t n_ubatch = llama_n_ubatch(ctx);

        // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
+        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
@@ -3062,7 +3063,7 @@ struct server_context {
                    // non-causal tasks require to fit the entire prompt in the physical batch
                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
-                        if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
+                        if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3082,11 +3083,12 @@ struct server_context {
                    slot.cache_tokens.resize(slot.n_past);

                    // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
                        // without pooling, we want to output the embeddings for all the tokens in the batch
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

-                        common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+                        std::array<llama_token, 1> seq_id = { slot.id };
+                        llama_batch_ext_add_text_token(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);

                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3096,13 +3098,13 @@ struct server_context {
                        slot.n_past++;
                    }

-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_DONE_PROMPT;

-                        GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);
+                        GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);

                        common_sampler_reset(slot.smpl);

@@ -3112,27 +3114,27 @@ struct server_context {
                        }

                        // extract the logits only for the last token
-                        llama_batch_set_logits_last(batch.get());
+                        llama_batch_ext_set_logits_last(batch.get());

                        slot.n_decoded = 0;
-                        slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1;
+                        slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;

-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
                    }
                }

-                if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
+                if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
                    break;
                }
            }
        }

-        if (llama_batch_get_n_tokens(batch.get()) == 0) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }

-        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));
+        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));

        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3142,12 +3144,12 @@ struct server_context {
        }

        // process the created batch of tokens
-        for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);
+        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);

-            llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));
+            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));

-            const int ret = llama_decode(ctx, batch_view.get());
+            const int ret = llama_text_decode(ctx, batch_view.get());
            metrics.on_decoded(slots);

            if (ret != 0) {
@@ -3282,16 +3284,17 @@ struct server_context {
            }

            // construct the speculation batch
-            common_batch_clear(slot.batch_spec.get());
-            common_batch_add (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);
+            llama_batch_ext_clear(slot.batch_spec.get());
+            std::array<llama_token, 1> seq_id = { slot.id };
+            llama_batch_ext_add_text_token(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);

            for (size_t i = 0; i < draft.size(); ++i) {
-                common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                llama_batch_ext_add_text_token(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, seq_id.data(), seq_id.size(), true);
            }

-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));

-            llama_decode(ctx, slot.batch_spec.get());
+            llama_text_decode(ctx, slot.batch_spec.get());

            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
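For reference, the speculation batch above holds the token accepted by the target model at position n_past, followed by the draft tokens at consecutive positions, all flagged for logits so the target model's outputs can be checked against the draft. A sketch under the same assumed call shapes; the helper itself and its parameters are illustrative:

// Sketch of the speculation-batch construction shown in this hunk.
#include "llama.h"

#include <array>
#include <vector>

static void build_spec_batch(llama_batch_ext * batch_spec, llama_token accepted,
                             const std::vector<llama_token> & draft,
                             llama_pos n_past, llama_seq_id slot_id) {
    llama_batch_ext_clear(batch_spec);

    std::array<llama_token, 1> seq_id = { slot_id };

    // the accepted token goes in first, at n_past
    llama_batch_ext_add_text_token(batch_spec, accepted, n_past, seq_id.data(), seq_id.size(), true);

    // draft tokens follow at consecutive positions, each requesting logits
    // so they can be verified against the target model's predictions
    for (size_t i = 0; i < draft.size(); ++i) {
        llama_batch_ext_add_text_token(batch_spec, draft[i], n_past + 1 + (llama_pos) i, seq_id.data(), seq_id.size(), true);
    }
}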
