Commit 5e6c7ba

abstract out the batch management

1 parent a44029a commit 5e6c7ba

File tree: 3 files changed, +147 -88 lines changed

examples/llava/mtmd.cpp

Lines changed: 1 addition & 13 deletions
```diff
@@ -112,7 +112,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
-    std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
+    std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
     output.clear();
     output.reserve(parts.size());
 
@@ -196,18 +196,6 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
     return image_tokens->id;
 }
 
-size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
-    return image_tokens->n_tokens();
-}
-
-size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
-    return image_tokens->nx;
-}
-
-size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
-    return image_tokens->ny;
-}
-
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
```

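The first hunk fixes a small bug: `prompt_modified` is built with the per-model marker rewrite applied, but the split still ran on the original `text.text`, so the rewrite was silently discarded. A minimal illustration of the corrected flow, using the default `<__image__>` marker and a placeholder rewrite (the real value is model-specific, e.g. the gemma 3 path wraps the marker in `<start_of_image>`/`<end_of_image>`):

```cpp
// Illustration only; the rewritten marker value is a placeholder.
std::string prompt_modified = "describe <__image__> please";   // copy of text.text
std::string marker_modified = "<start_of_image><__image__><end_of_image>";

// per-model rewrite of the image marker
string_replace_all(prompt_modified, "<__image__>", marker_modified);

// the fix: split the *rewritten* prompt, not the original text.text,
// so the wrapper tokens end up in the surrounding text parts
std::vector<std::string> parts = string_split_str(prompt_modified, "<__image__>");
// parts == { "describe <start_of_image>", "<end_of_image> please" }
```
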
examples/server/server.cpp

Lines changed: 64 additions & 47 deletions
```diff
@@ -1859,7 +1859,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch batch = {};
+    server_batch batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1897,8 +1897,6 @@
 
             llama_batch_free(slot.batch_spec);
         }
-
-        llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params) {
@@ -2035,9 +2033,7 @@
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
-
-            // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch = server_batch(std::max(n_batch, params_base.n_parallel));
        }
 
        metrics.init();
```
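
The `server_batch` type that replaces the raw `llama_batch` is defined in the third changed file, which this page does not show. As a reading aid, here is a hedged sketch of a wrapper that would satisfy the call sites visible in this diff (`server_batch(n)`, `clear()`, `n_tokens()`, `has_embd()`, a raw `batch` member, and RAII freeing that replaces the deleted `llama_batch_free(batch)`). Every detail is an assumption, in particular how the image-embedding mode is represented; the `batch.logits[...]` access in a later hunk suggests the real wrapper exposes more of the underlying arrays than this sketch does:

```cpp
#include <utility> // std::swap; llama.h / common.h assumed included

// Hypothetical sketch of server_batch, inferred only from its call sites in
// this commit; the real definition lives in the third changed file.
struct server_batch {
    llama_batch batch = {};
    bool is_embd = false; // assumed flag, set when server_encode_image fills the batch

    server_batch() = default;
    explicit server_batch(int32_t n_tokens_alloc) {
        // only a single seq_id per token is needed, same as the old init call
        batch = llama_batch_init(n_tokens_alloc, 0, 1);
    }

    // move-only: the destructor owns what llama_batch_init allocated,
    // replacing the llama_batch_free(batch) removed from ~server_context
    server_batch(const server_batch &) = delete;
    server_batch & operator=(server_batch && other) noexcept {
        std::swap(batch,   other.batch);
        std::swap(is_embd, other.is_embd);
        return *this;
    }

    void clear() {
        common_batch_clear(batch);
        is_embd = false;
    }

    int32_t n_tokens() const { return batch.n_tokens; }
    bool    has_embd() const { return is_embd; }

    ~server_batch() { llama_batch_free(batch); }
};
```
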
```diff
@@ -2934,7 +2930,7 @@
         }*/
 
         // start populating the batch for this iteration
-        common_batch_clear(batch);
+        batch.clear();
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2956,9 +2952,9 @@
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            slot.i_batch = batch.n_tokens();
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch.batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
@@ -2974,12 +2970,8 @@
         int32_t n_batch = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // for multimodal
-        bool is_decoding_embd = false;
-        server_embd_batch batch_embd;
-
        // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || batch.n_tokens() == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
@@ -3147,7 +3139,7 @@
                    // non-causal tasks require to fit the entire prompt in the physical batch
                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3167,36 +3159,55 @@
                    slot.cache_tokens.keep_until(slot.n_past);
 
                    // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) {
                        // without pooling, we want to output the embeddings for all the tokens in the batch
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                        auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
                        if (curr_chunk.tok_image) {
-                            // decode image
-                            server_encode_image(slot.mctx, batch_embd, curr_chunk, slot.n_past, slot.id);
-                            is_decoding_embd = true;
-                            SLT_INF(slot, "decoding image, n_past = %d, n_tokens = %d\n", slot.n_past, batch_embd.batch.n_tokens);
-                            slot.n_past += batch_embd.batch.n_tokens;
-                            break; // do not process any other slots
+                            // if there are already TEXT tokens in the batch, we need to process them first
+                            if (batch.batch.n_tokens > 0) {
+                                break;
+                            }
+                            // encode the image
+                            server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id);
+                            GGML_ASSERT(batch.has_embd());
+                            SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens());
+
+                            if (slot.params.cache_prompt) {
+                                slot.cache_tokens.add_image_tokens(curr_chunk.tok_image);
+                            }
+
+                            slot.n_past += batch.n_tokens();
+                            slot.n_prompt_tokens_processed += batch.n_tokens();
+                            break; // we cannot have both text batch and image batch
+
                        } else {
-                            common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
+                            GGML_ASSERT(!batch.has_embd());
+                            common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
                            if (slot.params.cache_prompt) {
                                slot.cache_tokens.add_text_token(curr_chunk.tok_text);
                            }
+
+                            slot.n_prompt_tokens_processed++;
+                            slot.n_past++;
                        }
+                    }
+
+                    SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
-                        slot.n_prompt_tokens_processed++;
-                        slot.n_past++;
+                    if (batch.has_embd()) {
+                        // currently, we can only process one image at a time, so we skip other slots
+                        break;
                    }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(batch.n_tokens() > 0);
 
                        common_sampler_reset(slot.smpl);
```
```diff
@@ -3209,27 +3220,32 @@
                        }
 
                        // extract the logits only for the last token
-                        batch.logits[batch.n_tokens - 1] = true;
+                        batch.logits[batch.n_tokens() - 1] = true;
 
                        slot.n_decoded = 0;
-                        slot.i_batch = batch.n_tokens - 1;
+                        slot.i_batch = batch.n_tokens() - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens());
                    }
                }
 
-                if (batch.n_tokens >= n_batch) {
+                if (batch.n_tokens() >= n_batch) {
                    break;
                }
            }
        }
 
-        if (batch.n_tokens == 0) {
+        if (batch.n_tokens() == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+        // debug
+        if (batch.has_embd()) {
+            SRV_INF("decoding embd batch, n_tokens = %d\n", batch.n_tokens());
+        } else {
+            SRV_INF("decoding batch, n_tokens = %d\n", batch.n_tokens());
+        }
 
        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3239,28 +3255,29 @@
        }
 
        // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+        for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i);
 
-            llama_batch batch_view = is_decoding_embd ? batch_embd.batch : llama_batch{
+            // TODO @ngxson : hacky here, we don't want to split the embd batch
+            llama_batch batch_view = batch.has_embd() ? batch.batch : llama_batch{
                n_tokens,
-                batch.token + i,
+                batch.batch.token + i,
                nullptr,
-                batch.pos + i,
-                batch.n_seq_id + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.batch.pos + i,
+                batch.batch.n_seq_id + i,
+                batch.batch.seq_id + i,
+                batch.batch.logits + i,
            };
 
            // TODO @ngxson : maybe move this to llama_batch_ext
-            if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) {
+            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
                llama_set_causal_attn(ctx, false);
            }
 
            const int ret = llama_decode(ctx, batch_view);
            metrics.on_decoded(slots);
 
-            if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) {
+            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
                llama_set_causal_attn(ctx, true);
            }
```
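The `batch_view` aggregate above mirrors `llama_batch`'s field order (`n_tokens`, `token`, `embd`, `pos`, `n_seq_id`, `seq_id`, `logits`). An image batch is the same struct with `embd` set and `token` null, which is why it cannot be windowed with `+ i` offsets and is passed through whole. As a hedged sketch of what the embedding side plausibly looks like, in the spirit of the `server_embd_batch` this commit absorbs into `server_batch` (all names and details assumed):

```cpp
#include <vector> // llama.h assumed included

// Hedged sketch: wrap n_tokens rows of precomputed image embeddings in a
// llama_batch for decoding; consecutive positions, one sequence, no logits.
struct embd_batch_view {
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id>   seq_id_0;  // the single shared sequence id
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t>         logits;
    llama_batch batch;

    embd_batch_view(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        pos.resize(n_tokens);
        for (int32_t i = 0; i < n_tokens; i++) {
            pos[i] = pos_0 + i;            // image tokens continue the text positions
        }
        n_seq_id.assign(n_tokens, 1);
        seq_id_0.assign(1, seq_id);
        seq_ids.assign(n_tokens + 1, seq_id_0.data());
        seq_ids[n_tokens] = nullptr;       // conventional terminator
        logits.assign(n_tokens, 0);        // no logits for image positions
        batch = llama_batch{
            /*n_tokens =*/ n_tokens,
            /*token    =*/ nullptr,        // embeddings instead of token ids
            /*embd     =*/ embd,           // n_tokens * n_embd floats, owned elsewhere
            /*pos      =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id   =*/ seq_ids.data(),
            /*logits   =*/ logits.data(),
        };
    }
};
```

A `llama_decode(ctx, view.batch)` then consumes the embeddings directly, which is why the loop above hands `batch.batch` through unsplit whenever `has_embd()` is true.
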
```diff
@@ -4006,13 +4023,13 @@
                    /* add_special   */ true,
                    /* parse_special */ true,
                };
-                mtmd_input_chunks * tokenized = mtmd_tokenize(ctx_server.mctx, inp_txt, bitmaps);
-                if (!tokenized) {
+                mtmd_input_chunks chunks;
+                int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps);
+                if (tokenized != 0) {
                    throw std::runtime_error("Failed to tokenize prompt");
                }
-                server_inputs tmp(tokenized);
+                server_inputs tmp(chunks);
                inputs.push_back(std::move(tmp));
-                mtmd_input_chunks_free(tokenized, false); // only free the container, not the images
            }
        } else {
            // non-multimodal version
```
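This last hunk reflects an `mtmd_tokenize` signature change: instead of returning a heap-allocated `mtmd_input_chunks *` that the caller had to release with `mtmd_input_chunks_free`, it now fills a caller-owned `mtmd_input_chunks` and returns a status code. The resulting calling convention, condensed from the + lines above (the scope-based cleanup assumes `mtmd_input_chunks` is now a value type, which the removal of the free call suggests):

```cpp
// Condensed from the hunk above; error-handling style follows the original.
mtmd_input_chunks chunks;                          // caller-owned
int32_t res = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps);
if (res != 0) {
    throw std::runtime_error("Failed to tokenize prompt");
}
server_inputs tmp(chunks);                         // takes what it needs from chunks
inputs.push_back(std::move(tmp));
// no mtmd_input_chunks_free(...) needed: chunks cleans up when it leaves scope
```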