
Commit 16affc5

add --truncate-embed
1 parent 0a5a3b5 commit 16affc5

4 files changed: +36 lines, -5 lines

common/arg.cpp (7 additions, 0 deletions)

@@ -2748,6 +2748,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.embedding = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    add_opt(common_arg(
+        {"--truncate-embed"},
+        string_format("allow truncation for embedding tasks to handle large inputs (default: %s)", params.truncate_embed ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.truncate_embed = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TRUNCATE_EMBED"));
     add_opt(common_arg(
         {"--reranking", "--rerank"},
         string_format("enable reranking endpoint on server (default: %s)", "disabled"),

common/common.h (1 addition, 0 deletions)

@@ -356,6 +356,7 @@ struct common_params {
 
     // embedding
     bool embedding = false; // get only sentence embedding
+    bool truncate_embed = false; // allow truncation for embedding tasks to handle large inputs
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings

tools/server/README.md (5 additions, 0 deletions)

@@ -159,6 +159,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
 | `--no-webui` | Disable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_NO_WEBUI) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
+| `--truncate-embed` | allow truncation for embedding tasks to handle large inputs (default: disabled)<br/>(env: LLAMA_ARG_TRUNCATE_EMBED) |
 | `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
 | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
 | `--api-key-file FNAME` | path to file containing API keys (default: none) |
@@ -636,6 +637,8 @@ Returns a JSON object with a field `prompt` containing a string of the input mes
 
 The same as [the embedding example](../embedding) does.
 
+**Note**: By default, embedding tasks cannot be split across multiple batches for safety. For large inputs that exceed the batch size, use the `--truncate-embed` flag to enable automatic truncation. When truncation occurs, the `truncated` field in the response will indicate this.
+
 *Options:*
 
 `content`: Set the text to process.
@@ -1175,6 +1178,8 @@ curl http://localhost:8080/v1/chat/completions \
 
 This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm.
 
+**Note**: By default, embedding tasks cannot be split across multiple batches for safety. For large inputs that exceed the batch size, use the `--truncate-embed` flag to enable automatic truncation. When truncation occurs, the `truncated` field in the response will indicate this.
+
 *Options:*
 
 See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
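A hedged request/response sketch for the note added above, using the `/embedding` endpoint this README documents. The long input string and the exact response envelope are illustrative; only the `truncated` field is guaranteed by this commit:

```sh
# assumes a server started with --embedding --truncate-embed on port 8080
curl http://localhost:8080/embedding \
    -H "Content-Type: application/json" \
    -d '{"content": "<text longer than the physical batch size>"}'

# abbreviated response: "truncated" reports whether the input had to be cut down
# [{"index": 0, "embedding": [[0.0123, -0.0456, ...]], "truncated": true}]
```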

tools/server/server.cpp (23 additions, 5 deletions)

@@ -1034,6 +1034,7 @@ struct server_task_result_embd : server_task_result {
     std::vector<std::vector<float>> embedding;
 
     int32_t n_tokens;
+    bool truncated = false;
 
     // OAI-compat fields
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -1052,6 +1053,7 @@ struct server_task_result_embd : server_task_result {
         return json {
             {"index", index},
             {"embedding", embedding},
+            {"truncated", truncated},
         };
     }
 
@@ -1060,6 +1062,7 @@ struct server_task_result_embd : server_task_result {
             {"index", index},
             {"embedding", embedding[0]},
             {"tokens_evaluated", n_tokens},
+            {"truncated", truncated},
         };
     }
 };
@@ -1360,10 +1363,16 @@ struct server_slot {
 
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
-    bool can_split() const {
+    // NOTE: When embedding inputs are truncated, the resulting embedding may not fully represent
+    // the original input. The 'truncated' field in the response indicates when this occurs.
+    //
+    // @param truncate_embed: if true, allows splitting for embedding tasks to handle large inputs
+    // with automatic truncation. If false, uses original conservative logic.
+    bool can_split(bool truncate_embed = false) const {
         return
             !need_embd() ||
-            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+            (need_embd() && truncate_embed); // allow splitting for embedding tasks only if truncate_embed is enabled
     }
 
     bool can_batch_with(server_slot & other_slot) const {
@@ -2570,12 +2579,15 @@ struct server_context {
         res->id = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
+        res->truncated = slot.truncated;
         res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_model_n_embd(model);
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
+        // Note: If the input was truncated (slot.truncated == true), this embedding
+        // represents only the processed portion of the original input
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
@@ -3129,7 +3141,7 @@ struct server_context {
                     continue;
                 }
 
-                if (!slot.can_split()) {
+                if (!slot.can_split(params_base.truncate_embed)) {
                     if (slot.n_prompt_tokens > n_ubatch) {
                         slot.release();
                         send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3146,7 +3158,8 @@ struct server_context {
                     // if context shift is disabled, we make sure prompt size is smaller than KV size
                     // TODO: there should be a separate parameter that control prompt truncation
                     // context shift should be applied only during the generation phase
-                    if (slot.n_prompt_tokens >= slot.n_ctx) {
+                    // For embedding tasks, allow truncation even when context shift is disabled
+                    if (slot.n_prompt_tokens >= slot.n_ctx && !slot.need_embd()) {
                         slot.release();
                         send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
                         continue;
@@ -3185,6 +3198,11 @@ struct server_context {
                         slot.n_prompt_tokens = prompt_tokens.size();
 
                         SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
+
+                        // Warn specifically for embedding tasks about potential quality impact
+                        if (slot.need_embd()) {
+                            SLT_WRN(slot, "%s", "WARNING: Embedding input was truncated. The resulting embedding may not fully represent the original input. Consider increasing context size or reducing input length for better embedding quality.");
+                        }
 
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                     }
@@ -3272,7 +3290,7 @@ struct server_context {
                     slot.n_prompt_tokens_processed = 0;
                 }
 
-                if (!slot.can_split()) {
+                if (!slot.can_split(params_base.truncate_embed)) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
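To make the behavioral change in these hunks concrete, a hedged before/after sketch; the error text is the unchanged `send_error` message shown above, and the request shape follows the README:

```sh
# default server (no --truncate-embed): an embedding prompt larger than the physical batch
# cannot be split, so the request is rejected with
#   "input is too large to process. increase the physical batch size"
curl http://localhost:8080/embedding -d '{"content": "<very long text>"}'

# server started with --truncate-embed: can_split() now allows the embedding task,
# the input is truncated to fit, and the response carries "truncated": true
curl http://localhost:8080/embedding -d '{"content": "<very long text>"}'
```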
