Commit f2e59a8

rework, targeting llama-server

1 parent 4ed4fe7 commit f2e59a8

10 files changed: +190 -135 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ examples/server/*.css.hpp
 examples/server/*.html.hpp
 examples/server/*.js.hpp
 examples/server/*.mjs.hpp
+examples/server/*.gz.hpp
 !build_64.sh
 !examples/*.bat
 !examples/*/*.kts

common/common.cpp

Lines changed: 8 additions & 13 deletions
@@ -580,6 +580,7 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
     return buf.str();
 }
 
+/*
 std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
     std::stringstream buf;
 
@@ -614,6 +615,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
 
     return buf.str();
 }
+*/
 
 void string_process_escapes(std::string & input) {
     std::size_t input_len = input.length();
@@ -1608,27 +1610,20 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
 // Batch utils
 //
 
-void common_batch_clear(struct llama_batch & batch) {
-    batch.n_tokens = 0;
+void common_batch_clear(struct llama_batch * batch) {
+    llama_batch_clear(batch);
 }
 
 void common_batch_add(
-                 struct llama_batch & batch,
+                 struct llama_batch * batch,
                         llama_token id,
                           llama_pos pos,
     const std::vector<llama_seq_id> & seq_ids,
                                bool logits) {
-    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
-
-    batch.token   [batch.n_tokens] = id;
-    batch.pos     [batch.n_tokens] = pos;
-    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
-    for (size_t i = 0; i < seq_ids.size(); ++i) {
-        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits);
+    if (res == -1) {
+        LOG_ERR("%s: llama_batch size exceeded\n", __func__);
     }
-    batch.logits  [batch.n_tokens] = logits;
-
-    batch.n_tokens++;
 }
 
 //
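As a usage sketch (not part of this commit): callers now pass a llama_batch * and the helpers forward to the new opaque-batch API. The prototypes of llama_batch_clear, llama_batch_add_text_token and llama_batch_get_n_tokens are inferred from the calls in this diff and may differ from the final llama.h.

// usage sketch only - signatures inferred from this commit, not taken from llama.h
#include <vector>
#include "common.h"

static void build_prompt_batch(llama_batch * batch, const std::vector<llama_token> & prompt, llama_seq_id seq) {
    common_batch_clear(batch); // now forwards to llama_batch_clear()

    for (size_t i = 0; i < prompt.size(); ++i) {
        // request logits only for the last prompt token
        common_batch_add(batch, prompt[i], (llama_pos) i, { seq }, i + 1 == prompt.size());
    }

    // the batch size is read through an accessor instead of batch.n_tokens
    GGML_ASSERT(llama_batch_get_n_tokens(batch) == (int32_t) prompt.size());
}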

common/common.h

Lines changed: 2 additions & 2 deletions
@@ -554,10 +554,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
 // Batch utils
 //
 
-void common_batch_clear(struct llama_batch & batch);
+void common_batch_clear(struct llama_batch * batch);
 
 void common_batch_add(
-                 struct llama_batch & batch,
+                 struct llama_batch * batch,
                         llama_token id,
                           llama_pos pos,
     const std::vector<llama_seq_id> & seq_ids,

common/speculative.cpp

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_batch batch;
+    llama_batch * batch;
     llama_tokens prompt;
 };
 
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 1),
         /* .prompt = */ {},
     };
 
@@ -215,7 +215,7 @@
     }
 
    // we should rarely end-up here during normal decoding
-    if (batch.n_tokens > 0) {
+    if (llama_batch_get_n_tokens(batch) > 0) {
        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
        llama_decode(ctx, batch);

examples/server/server.cpp

Lines changed: 43 additions & 55 deletions
@@ -1215,7 +1215,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch batch_spec = {};
+    llama_batch_ptr batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch batch = {};
+    llama_batch_ptr batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
@@ -1820,11 +1820,7 @@
 
            common_speculative_free(slot.spec);
            slot.spec = nullptr;
-
-            llama_batch_free(slot.batch_spec);
        }
-
-        llama_batch_free(batch);
    }
 
    bool load_model(const common_params & params) {
@@ -1944,7 +1940,7 @@
            slot.n_predict = params_base.n_predict;
 
            if (model_dft) {
-                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+                slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));
 
                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                if (slot.ctx_dft == nullptr) {
@@ -1969,7 +1965,7 @@
 
            slot.reset();
 
-            slots.push_back(slot);
+            slots.push_back(std::move(slot));
        }
 
        default_generation_settings_for_props = slots[0].to_json();
@@ -1980,7 +1976,7 @@
            const int32_t n_batch = llama_n_batch(ctx);
 
            // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
        }
 
        metrics.init();
@@ -2098,9 +2094,7 @@
        }
 
        if (slot.ctx_dft) {
-            llama_batch_free(slot.batch_spec);
-
-            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+            slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
        }
 
        slot.state = SLOT_STATE_STARTED;
@@ -2408,7 +2402,7 @@
        queue_results.send(std::move(res));
    }
 
-    void send_embedding(const server_slot & slot, const llama_batch & batch) {
+    void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
        auto res = std::make_unique<server_task_result_embd>();
        res->id    = slot.id_task;
        res->index = slot.index;
@@ -2419,18 +2413,19 @@
 
        std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                continue;
            }
 
-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
            if (embd == NULL) {
                embd = llama_get_embeddings_ith(ctx, i);
            }
 
            if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 
                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                continue;
@@ -2451,24 +2446,25 @@
        queue_results.send(std::move(res));
    }
 
-    void send_rerank(const server_slot & slot, const llama_batch & batch) {
+    void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id       = slot.id_task;
        res->index    = slot.index;
        res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                continue;
            }
 
-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
            if (embd == NULL) {
                embd = llama_get_embeddings_ith(ctx, i);
            }
 
            if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 
                res->score = -1e6;
                continue;
@@ -2859,7 +2855,7 @@
        }
 
        // start populating the batch for this iteration
-        common_batch_clear(batch);
+        common_batch_clear(batch.get());
 
        // track if given slot can be batched with slots already in the batch
        server_slot * slot_batched = nullptr;
@@ -2881,9 +2877,9 @@
                continue;
            }
 
-            slot.i_batch = batch.n_tokens;
+            slot.i_batch = llama_batch_get_n_tokens(batch.get());
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);
 
            slot.n_past += 1;
 
@@ -2900,7 +2896,7 @@
        int32_t n_ubatch = llama_n_ubatch(ctx);
 
        // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
@@ -3066,7 +3062,7 @@
                    // non-causal tasks require to fit the entire prompt in the physical batch
                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3086,11 +3082,11 @@
                    slot.cache_tokens.resize(slot.n_past);
 
                    // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
                        // without pooling, we want to output the embeddings for all the tokens in the batch
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3100,13 +3096,13 @@
                        slot.n_past++;
                    }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);
 
                        common_sampler_reset(slot.smpl);
 
@@ -3116,27 +3112,27 @@
                    }
 
                    // extract the logits only for the last token
-                    batch.logits[batch.n_tokens - 1] = true;
+                    llama_batch_set_logits_last(batch.get());
 
                    slot.n_decoded = 0;
-                    slot.i_batch   = batch.n_tokens - 1;
+                    slot.i_batch   = llama_batch_get_n_tokens(batch.get()) - 1;
 
-                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
                }
            }
 
-            if (batch.n_tokens >= n_batch) {
+            if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
                break;
            }
        }
    }
 
-    if (batch.n_tokens == 0) {
+    if (llama_batch_get_n_tokens(batch.get()) == 0) {
        SRV_WRN("%s", "no tokens to decode\n");
        return;
    }
 
-    SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+    SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));
 
    if (slot_batched) {
        // make sure we're in the right embedding mode
@@ -3146,20 +3142,12 @@
        }
 
        // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
+        for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);
+
+            llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));
 
-            const int ret = llama_decode(ctx, batch_view);
+            const int ret = llama_decode(ctx, batch_view.get());
            metrics.on_decoded(slots);
 
            if (ret != 0) {
@@ -3294,16 +3282,16 @@
            }
 
            // construct the speculation batch
-            common_batch_clear(slot.batch_spec);
-            common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+            common_batch_clear(slot.batch_spec.get());
+            common_batch_add (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);
 
            for (size_t i = 0; i < draft.size(); ++i) {
-                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
            }
 
-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));
 
-            llama_decode(ctx, slot.batch_spec);
+            llama_decode(ctx, slot.batch_spec.get());
 
            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
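For orientation, a condensed sketch of the two read paths used above: chunked decoding through batch views, and per-token inspection through llama_batch_token_info. This mirrors the decode loop and send_embedding()/send_rerank(); the behaviour of llama_batch_get_view and llama_batch_get_token_info (in particular, that a view is freed like a normal batch, hence the llama_batch_ptr wrapper) is assumed from this diff rather than documented API.

// condensed sketch of the patterns above - assumes the API behaves as this diff implies
#include <algorithm>
#include "llama-cpp.h"

static int decode_chunked(llama_context * ctx, llama_batch * batch, int32_t n_batch) {
    for (int32_t i = 0; i < llama_batch_get_n_tokens(batch); i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch) - i);

        // views replace the manual llama_batch { ... } slicing used before;
        // wrapping in llama_batch_ptr assumes the view is owned by the caller
        llama_batch_ptr view(llama_batch_get_view(batch, i, n_tokens));

        if (llama_decode(ctx, view.get()) != 0) {
            return -1;
        }
    }

    // per-token fields are no longer reachable via batch.token/seq_id/logits;
    // they come back through llama_batch_get_token_info()
    for (int32_t i = 0; i < llama_batch_get_n_tokens(batch); ++i) {
        const llama_batch_token_info tok = llama_batch_get_token_info(batch, i);
        if (tok.logits) {
            // tok.token, tok.seq_id[0], ... as used by send_embedding()/send_rerank()
        }
    }
    return 0;
}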

include/llama-cpp.h

Lines changed: 5 additions & 0 deletions
@@ -24,7 +24,12 @@ struct llama_adapter_lora_deleter {
     void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
 };
 
+struct llama_batch_deleter {
+    void operator()(llama_batch * batch) { llama_batch_free(batch); }
+};
+
 typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
 typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
 typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
 typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
+typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr;
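A minimal sketch of how llama_batch_ptr is meant to be used (this is how server_context and server_slot hold their batches above); the two-argument llama_batch_init(n_tokens, n_seq_max) signature is taken from the calls in this commit and may differ in the final header.

// sketch: RAII ownership of a llama_batch via llama_batch_ptr
#include "llama-cpp.h"

void example(int32_t n_batch) {
    llama_batch_ptr batch(llama_batch_init(n_batch, 1)); // n_tokens, n_seq_max (per this commit)

    // pass the raw pointer to the C API
    llama_batch_clear(batch.get());

    // re-sizing: reset() frees the old batch through llama_batch_deleter
    batch.reset(llama_batch_init(2 * n_batch, 1));
} // remaining batch freed automatically here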
