Commit 8afa952

allow decoding image embedding to be split into batches
1 parent cd11585 commit 8afa952

2 files changed: +121 additions, -69 deletions

examples/server/server.cpp

Lines changed: 40 additions & 48 deletions
@@ -1860,7 +1860,8 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    server_batch batch;
+    llama_batch batch;
+    server_batch_embd batch_embd;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1898,6 +1899,8 @@ struct server_context {
 
             llama_batch_free(slot.batch_spec);
         }
+
+        llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params) {
@@ -2034,7 +2037,8 @@ struct server_context {
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
-            batch = server_batch(std::max(n_batch, params_base.n_parallel));
+            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch_embd = server_batch_embd(std::max(n_batch, params_base.n_parallel));
         }
 
         metrics.init();
@@ -2931,7 +2935,7 @@ struct server_context {
         }*/
 
         // start populating the batch for this iteration
-        batch.clear();
+        common_batch_clear(batch);
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2953,9 +2957,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens();
+            slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch.batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
@@ -2972,7 +2976,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens() == 0) {
+        if (params_base.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
@@ -3140,7 +3144,7 @@ struct server_context {
                 // non-causal tasks require to fit the entire prompt in the physical batch
                 if (slot.is_non_causal()) {
                     // cannot fit the prompt in the current batch - will try next iter
-                    if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) {
+                    if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
                     }
                 }
@@ -3163,28 +3167,26 @@ struct server_context {
 
                     // check if we should process the image
                     if (curr_chunk.tok_image) {
-                        if (batch.has_text()) {
-                            continue; // we cannot have both text batch and image batch
+                        // process the image
+                        int32_t res = server_img_process(ctx, mctx, curr_chunk, batch_embd, slot.n_past, slot.id);
+                        if (res != 0) {
+                            SLT_ERR(slot, "failed to process image, res = %d\n", res);
+                            slot.release();
+                            send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
+                            continue;
                         }
 
-                        // encode the image
-                        server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id);
-                        GGML_ASSERT(batch.has_embd());
-                        SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens());
-
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.add_image_tokens(curr_chunk.tok_image);
                         }
 
-                        slot.n_past += batch.n_tokens();
-                        slot.n_prompt_tokens_processed += batch.n_tokens();
-
-                        break; // currently, we can only process one image at a time, so we skip ALL other slots
+                        slot.n_past += curr_chunk.n_tokens;
+                        slot.n_prompt_tokens_processed += curr_chunk.n_tokens;
                     }
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) {
-                        GGML_ASSERT(!batch.has_embd());
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        // get next token to process
                         auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
                         if (curr_chunk.tok_text == LLAMA_TOKEN_NULL) {
                             break; // end of text chunk
@@ -3193,7 +3195,7 @@ struct server_context {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.add_text_token(curr_chunk.tok_text);
                         }
@@ -3204,47 +3206,47 @@ struct server_context {
 
                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens() > 0);
+                        GGML_ASSERT(batch.n_tokens > 0);
 
                         common_sampler_reset(slot.smpl);
 
                         // Process all prompt tokens through sampler system
                         for (size_t i = 0; i < slot.cache_tokens.n_tokens(); ++i) {
-                            auto & curr_chunk = slot.cache_tokens.get_chunk(i);
+                            auto & curr_chunk = slot.prompt_tokens.get_chunk(i);
                             if (curr_chunk.tok_text != LLAMA_TOKEN_NULL) {
                                 common_sampler_accept(slot.smpl, curr_chunk.tok_text, false);
                             }
                         }
 
                         // extract the logits only for the last token
-                        batch.logits[batch.n_tokens() - 1] = true;
+                        batch.logits[batch.n_tokens - 1] = true;
 
                         slot.n_decoded = 0;
-                        slot.i_batch = batch.n_tokens() - 1;
+                        slot.i_batch = batch.n_tokens - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens());
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
                     }
                 }
 
-                if (batch.n_tokens() >= n_batch) {
+                if (batch.n_tokens >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (batch.n_tokens() == 0) {
+        if (batch.n_tokens == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
        // debug
-        SRV_DBG("decoding %s batch, n_tokens = %d\n", batch.has_embd() ? "embd" : "text", batch.n_tokens());
+        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3254,32 +3256,22 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i);
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
-            // TODO @ngxson : hacky here, we don't want to split the embd batch
-            llama_batch batch_view = batch.has_embd() ? batch.batch : llama_batch{
+            llama_batch batch_view = llama_batch{
                 n_tokens,
-                batch.batch.token + i,
+                batch.token + i,
                 nullptr,
-                batch.batch.pos + i,
-                batch.batch.n_seq_id + i,
-                batch.batch.seq_id + i,
-                batch.batch.logits + i,
+                batch.pos + i,
+                batch.n_seq_id + i,
+                batch.seq_id + i,
+                batch.logits + i,
             };
 
-            // TODO @ngxson : maybe move this to llama_batch_ext
-            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
-                llama_set_causal_attn(ctx, false);
-            }
-
             const int ret = llama_decode(ctx, batch_view);
             metrics.on_decoded(slots);
 
-            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
-                llama_set_causal_attn(ctx, true);
-            }
-
             if (ret != 0) {
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
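Note on the decode loop at the end of the diff above: since batch is now a plain llama_batch, each slice is just a view over the parent arrays built with pointer offsets, so nothing is copied per decode call. The struct below is a pared-down stand-in, not the real llama_batch layout, used only to illustrate the aliasing:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// pared-down stand-in for llama_batch: just the fields the view trick relies on
struct toy_batch {
    int32_t   n_tokens;
    int32_t * token;
    int32_t * pos;
    int8_t  * logits;
};

int main() {
    const int32_t n_total = 10, n_batch = 4;

    std::vector<int32_t> token(n_total), pos(n_total);
    std::vector<int8_t>  logits(n_total, 0);
    for (int32_t i = 0; i < n_total; ++i) { token[i] = 100 + i; pos[i] = i; }

    toy_batch batch { n_total, token.data(), pos.data(), logits.data() };

    // same slicing idea as the server decode loop: each view aliases a window of the parent arrays
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
        toy_batch view { n_tokens, batch.token + i, batch.pos + i, batch.logits + i };
        printf("view at offset %d holds %d tokens, first pos = %d\n", i, view.n_tokens, view.pos[0]);
    }
    return 0;
}

The same idea carries over to the token, pos, n_seq_id, seq_id, and logits pointers in the real batch_view.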

examples/server/utils.hpp

Lines changed: 81 additions & 21 deletions
@@ -963,6 +963,8 @@ static std::vector<common_adapter_lora_info> parse_lora_request(
 // (may need to refactor in near future)
 //
 
+// each chunk can contain either one SINGLE text token or an image (multiple token embeddings)
+// this is to simplify the logic of KV cache management
 struct server_inp_chunk {
     size_t n_tokens = 1; // always 1 in case of text
     llama_token tok_text;
@@ -981,6 +983,15 @@
  * server_inputs is a helper to manage the input tokens and image for the server.
  *
  * the difference between server_inputs and mtmd_input_chunks is that each chunk of server_inputs only contains a single text token, but text chunk of mtmd_input_chunks can contain multiple tokens.
+ *
+ * for example, server_inputs may contain 5 text tokens followed by 1 image chunk:
+ *     1 41 2635 325 463 <image of 15 tokens>
+ *
+ * in this example:
+ * - n_tokens() returns 5+15 = 20 total tokens
+ * - get_chunk(1) returns chunk containing token ID 41
+ * - get_chunk(5) returns image chunk (15 tokens)
+ * - get_chunk(7) returns same image chunk
  *
  * it is made this way to simplify the logic of KV cache management.
  */
@@ -1079,6 +1090,7 @@ struct server_inputs {
         return ret;
     }
 
+    // make sure all text tokens are within the vocab range
     bool validate(llama_token max_vocab_id) const {
         for (const auto & chunk : chunks) {
             if (!chunk.tok_image) {
@@ -1090,24 +1102,26 @@
         return true;
     }
 
+    // pos is also referred as logical index
     server_inp_chunk & get_chunk(size_t pos) {
-        return chunks[get_chunk_idx(pos)];
+        size_t physical_idx = get_chunk_physical_idx(pos);
+        return chunks[physical_idx];
     }
 
-    size_t get_chunk_idx(size_t pos) const {
+    // returns physical_index
+    size_t get_chunk_physical_idx(size_t logical_idx) const {
         size_t current_pos = 0;
         for (size_t i = 0; i < chunks.size(); ++i) {
             const auto & chunk = chunks[i];
             size_t chunk_end_pos = current_pos + chunk.n_tokens;
-            if (pos < chunk_end_pos) {
+            if (logical_idx < chunk_end_pos) {
                 // The target position 'pos' falls within this chunk
                 return i;
             }
-
             current_pos = chunk_end_pos;
         }
         // If the loop finishes, 'pos' is >= the total number of logical positions
-        return chunks.size();
+        throw std::out_of_range("Position out of range");
     }
 
     // same idea with std::vector<llama_token> resize()
@@ -1164,7 +1178,7 @@ struct server_inputs {
 
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
-struct server_batch {
+struct server_batch_embd {
     std::vector<llama_pos> pos;
     std::vector<llama_token> token;
     std::vector<int32_t> n_seq_id;
@@ -1174,8 +1188,8 @@ struct server_batch {
 
     llama_batch batch;
 
-    server_batch() : server_batch(1) {}
-    server_batch(int32_t n_tokens) {
+    server_batch_embd() : server_batch_embd(1) {}
+    server_batch_embd(int32_t n_tokens) {
         token .resize(n_tokens);
         pos .resize(n_tokens);
         n_seq_id.resize(n_tokens);
@@ -1233,23 +1247,69 @@ struct server_batch {
 };
 
 // TODO @ngxson : quite hacky for now, but just to see if it works
-static int32_t server_encode_image(mtmd_context * mctx, server_batch & batch_out, server_inp_chunk & chunk, llama_pos n_past, llama_seq_id seq_id) {
+static int32_t server_img_process(
+        llama_context * ctx,
+        mtmd_context * mctx,
+        server_inp_chunk & chunk,
+        server_batch_embd & batch,
+        llama_pos n_past,
+        int slot_id) {
     GGML_ASSERT(chunk.tok_image);
-    batch_out.clear();
-
-    int64_t t0 = ggml_time_ms();
-    LOG_INF("encoding image...\n");
-    int32_t ret = mtmd_encode(mctx, chunk.tok_image.get());
-    if (ret != 0) {
-        LOG_ERR("failed to encode image\n");
-        return ret;
+    int32_t ret;
+
+    // encode the image
+    {
+        int64_t t0 = ggml_time_ms();
+        SRV_INF("encoding image (%d tokens)...\n", (int)chunk.n_tokens);
+        ret = mtmd_encode(mctx, chunk.tok_image.get());
+        if (ret != 0) {
+            SRV_ERR("failed to encode image, status = %d\n", ret);
+            return ret;
+        }
+        SRV_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
     }
-    LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
 
-    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tok_image.get());
     float * embd = mtmd_get_output_embd(mctx);
-    batch_out.reserve_embd_batch(embd, n_tokens, n_past, seq_id);
-    return ret;
+    // decode the embeddings
+    int64_t t1 = ggml_time_ms();
+    int32_t n_embd = llama_model_n_embd(llama_get_model(ctx));
+    int32_t n_tokens = chunk.n_tokens;
+    int32_t n_batch = batch.pos.size();
+    int32_t i_batch = 0;
+    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    // split into batches
+    while (i_batch < n_img_batches) {
+        int32_t pos_offset = i_batch*n_batch;
+        int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+        float * embd_batch = embd + pos_offset*n_embd;
+        batch.clear();
+        batch.reserve_embd_batch(embd_batch, n_tokens_batch, n_past, slot_id);
+
+        SRV_INF("decoding embd batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+        // TODO @ngxson : maybe move this to llama_batch_ext
+        if (mtmd_decode_use_non_causal(mctx)) {
+            llama_set_causal_attn(ctx, false);
+        }
+
+        ret = llama_decode(ctx, batch.batch);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_set_causal_attn(ctx, true); // restore causal attn
+            return ret;
+        }
+
+        if (mtmd_decode_use_non_causal(mctx)) {
+            llama_set_causal_attn(ctx, true);
+        }
+
+        i_batch++;
+        n_past += n_tokens_batch;
+    }
+    SRV_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
+
+    batch.clear();
+    return 0;
 }
 
 // hacky, support text-only for now
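The logical-to-physical index walk introduced as get_chunk_physical_idx() is easy to check in isolation. The sketch below mirrors that walk with a toy chunk type (toy_chunk and chunk_physical_idx are illustrative names, not part of the tree) and reproduces the doc-comment example of 5 text tokens followed by a 15-token image:

#include <cstdio>
#include <stdexcept>
#include <vector>

// illustrative stand-in for server_inp_chunk: only the token count matters here
struct toy_chunk {
    size_t n_tokens; // 1 for a text token, >1 for an image chunk
};

// mirrors the walk in server_inputs::get_chunk_physical_idx(): map a logical token
// position onto the index of the chunk that covers it
static size_t chunk_physical_idx(const std::vector<toy_chunk> & chunks, size_t logical_idx) {
    size_t current_pos = 0;
    for (size_t i = 0; i < chunks.size(); ++i) {
        const size_t chunk_end_pos = current_pos + chunks[i].n_tokens;
        if (logical_idx < chunk_end_pos) {
            return i;
        }
        current_pos = chunk_end_pos;
    }
    throw std::out_of_range("Position out of range");
}

int main() {
    // 5 text tokens followed by one image chunk of 15 tokens, as in the doc comment
    std::vector<toy_chunk> chunks(5, toy_chunk{1});
    chunks.push_back(toy_chunk{15});

    printf("get_chunk(1) -> chunk %zu\n", chunk_physical_idx(chunks, 1)); // 1 (text token)
    printf("get_chunk(5) -> chunk %zu\n", chunk_physical_idx(chunks, 5)); // 5 (image chunk)
    printf("get_chunk(7) -> chunk %zu\n", chunk_physical_idx(chunks, 7)); // 5 (same image chunk)
    return 0;
}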
