
Commit f8bc466

refactor logic adding tokens to batch

1 parent: a6a3653

2 files changed: +44 additions, -39 deletions

examples/server/server.cpp

Lines changed: 36 additions & 34 deletions
@@ -2859,7 +2859,7 @@ struct server_context {
                res->id = task.id;
                queue_results.send(std::move(res));
            } break;
-
+
        }
    }

@@ -3159,49 +3159,51 @@ struct server_context {
            // remove the non-common part from the cache
            slot.cache_tokens.keep_until(slot.n_past);

-           // add prompt tokens for processing in the current batch
-           while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) {
-               // without pooling, we want to output the embeddings for all the tokens in the batch
-               const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+           auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);

-               auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
-               if (curr_chunk.tok_image) {
-                   // if there are already TEXT tokens in the batch, we need to process them first
-                   if (batch.batch.n_tokens > 0) {
-                       break;
-                   }
-                   // encode the image
-                   server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id);
-                   GGML_ASSERT(batch.has_embd());
-                   SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens());
+           // check if we should process the image
+           if (curr_chunk.tok_image) {
+               if (batch.has_text()) {
+                   continue; // we cannot have both text batch and image batch
+               }

-                   if (slot.params.cache_prompt) {
-                       slot.cache_tokens.add_image_tokens(curr_chunk.tok_image);
-                   }
+               // encode the image
+               server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id);
+               GGML_ASSERT(batch.has_embd());
+               SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens());

-                   slot.n_past += batch.n_tokens();
-                   slot.n_prompt_tokens_processed += batch.n_tokens();
-                   break; // we cannot have both text batch and image batch
+               if (slot.params.cache_prompt) {
+                   slot.cache_tokens.add_image_tokens(curr_chunk.tok_image);
+               }

-               } else {
-                   GGML_ASSERT(!batch.has_embd());
-                   common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
-                   if (slot.params.cache_prompt) {
-                       slot.cache_tokens.add_text_token(curr_chunk.tok_text);
-                   }
+               slot.n_past += batch.n_tokens();
+               slot.n_prompt_tokens_processed += batch.n_tokens();

-                   slot.n_prompt_tokens_processed++;
-                   slot.n_past++;
-               }
+               break; // currently, we can only process one image at a time, so we skip ALL other slots
            }

-           SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
+           // add prompt tokens for processing in the current batch
+           while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) {
+               GGML_ASSERT(!batch.has_embd());
+               auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
+               if (curr_chunk.tok_text == LLAMA_TOKEN_NULL) {
+                   break; // end of text chunk
+               }

-           if (batch.has_embd()) {
-               // currently, we can only process one image at a time, so we skip other slots
-               break;
+               // without pooling, we want to output the embeddings for all the tokens in the batch
+               const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+               common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
+               if (slot.params.cache_prompt) {
+                   slot.cache_tokens.add_text_token(curr_chunk.tok_text);
+               }
+
+               slot.n_prompt_tokens_processed++;
+               slot.n_past++;
            }

+           SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
+
            SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

            // entire prompt has been processed
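Taken together, the refactor replaces one mixed loop with two phases: an image chunk, if present, is handled first and owns the whole batch; only then are plain text tokens packed until the batch is full or the next image chunk is reached. Below is a minimal sketch of that control flow; chunk, batch_state, and add_prompt_chunks are hypothetical stand-ins for illustration, not the server's real API.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-ins for the server types, just to illustrate the flow.
    struct chunk {
        int32_t      tok_text  = -1;      // -1 plays the role of LLAMA_TOKEN_NULL
        const float *tok_image = nullptr; // non-null => this chunk is an image
    };

    struct batch_state {
        std::vector<int32_t> text_tokens;
        bool                 holds_embd = false;

        bool has_text() const { return !text_tokens.empty(); }
        bool has_embd() const { return holds_embd; }
    };

    // Mirrors the refactored per-slot logic: returns true if the slot must be
    // retried later because text is already pending in the batch.
    bool add_prompt_chunks(batch_state & batch, const std::vector<chunk> & prompt,
                           size_t & n_past, size_t n_batch) {
        const chunk & cur = prompt[n_past]; // caller guarantees n_past < prompt.size()

        // image chunk: only allowed to start an otherwise empty batch
        if (cur.tok_image) {
            if (batch.has_text()) {
                return true;          // text must be flushed first ("continue" in the diff)
            }
            batch.holds_embd = true;  // the real code calls server_encode_image() here
            n_past += 1;              // the real code advances by batch.n_tokens()
            return false;             // one image per batch ("break" in the diff)
        }

        // text chunks: pack single-token chunks until the batch is full
        // or the next image chunk is reached (tok_text == -1)
        while (n_past < prompt.size() && batch.text_tokens.size() < n_batch) {
            const chunk & c = prompt[n_past];
            if (c.tok_text == -1) {
                break;
            }
            batch.text_tokens.push_back(c.tok_text);
            n_past += 1;
        }
        return false;
    }

The invariant is the one the GGML_ASSERT calls in the diff encode: a batch holds either one image's embeddings or text tokens, never both.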

examples/server/utils.hpp

Lines changed: 8 additions & 5 deletions
@@ -668,7 +668,7 @@ static json oaicompat_completion_params_parse(
                    p["text"] = "<__image__>";
                    p.erase("image_url");
                }
-           }
+           }
        }

    common_chat_templates_inputs inputs;
@@ -979,9 +979,9 @@ struct server_inp_chunk {

/**
 * server_inputs is a helper to manage the input tokens and image for the server.
- *
+ *
 * the difference between server_inputs and mtmd_input_chunks is that each chunk of server_inputs only contains a single text token, but text chunk of mtmd_input_chunks can contain multiple tokens.
- *
+ *
 * it is made this way to simplify the logic of KV cache management.
 */
struct server_inputs {
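The comment above records the design choice the new server.cpp logic relies on: with one text token per chunk, a chunk index is also a KV cache position, so trimming the cache to the common prefix is a plain truncation. A hypothetical sketch of that layout (the field types are stand-ins; only the tok_text / tok_image shape mirrors the diff):

    #include <cstddef>
    #include <cstdint>
    #include <memory>
    #include <vector>

    // Hypothetical mirror of the one-token-per-chunk layout: index i in the
    // vector is also KV cache position i, whether the chunk is text or image.
    struct inp_chunk {
        int32_t              tok_text = -1; // one token, or -1 (LLAMA_TOKEN_NULL) for image chunks
        std::shared_ptr<int> tok_image;     // shared across the image's positions (stand-in type)
    };

    struct inputs_sketch {
        std::vector<inp_chunk> chunks;

        // trimming the cache to a common prefix is a plain truncation,
        // in the spirit of cache_tokens.keep_until(slot.n_past)
        void keep_until(size_t n) {
            if (n < chunks.size()) {
                chunks.resize(n);
            }
        }
    };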
@@ -1184,7 +1184,6 @@ struct server_batch {

    void reserve_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        GGML_ASSERT(n_tokens <= (int32_t)pos.size());
-       seq_ids[n_tokens] = nullptr;
        batch.n_tokens = n_tokens;
        batch.embd = embd;
        batch.token = nullptr;
@@ -1207,7 +1206,11 @@ struct server_batch {
    }

    bool has_embd() const {
-       return batch.embd != nullptr;
+       return batch.embd != nullptr && batch.n_tokens > 0;
+   }
+
+   bool has_text() const {
+       return batch.token != nullptr && batch.n_tokens > 0;
    }
};
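The n_tokens > 0 guard changes what the predicate answers: not "is this batch configured for embeddings" but "does it currently hold embedding tokens", and has_text() mirrors that for text, which is what the new if (batch.has_text()) check in server.cpp needs. A small sketch of the distinction, using a stand-in for llama_batch (not the real struct):

    #include <cassert>
    #include <cstdint>

    // Stand-in for llama_batch, only to show the predicate semantics.
    struct fake_batch {
        int32_t  n_tokens = 0;
        int32_t *token    = nullptr;
        float   *embd     = nullptr;
    };

    static bool has_embd_old(const fake_batch & b) { return b.embd != nullptr; }
    static bool has_embd_new(const fake_batch & b) { return b.embd != nullptr && b.n_tokens > 0; }

    int main() {
        float embd_buf[16] = {};
        fake_batch b;
        b.embd = embd_buf;       // configured for embeddings, but nothing added yet

        assert(has_embd_old(b)); // old predicate: merely "configured" already counts
        assert(!has_embd_new(b));// new predicate: an empty batch does not count

        b.n_tokens = 4;          // embeddings actually added
        assert(has_embd_new(b));
        return 0;
    }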