Commit d31c734

feat: Add optional prompt processing progress streaming
- Add include_prompt_progress parameter to slot_params (default: false)
- Extend server_task_result_cmpl_partial with progress fields
- Implement send_progress_response() function with 1% progress intervals
- Add progress response in prompt processing loop
- Update JSON response to include prompt_processing field when requested
- Add comprehensive documentation to README.md
- Ensure full backward compatibility with existing clients

Closes #14685
1 parent 21c0217 commit d31c734

2 files changed: +89, -6 lines changed

tools/server/README.md

Lines changed: 2 additions & 0 deletions
@@ -428,6 +428,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re
 
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.
 
+`include_prompt_progress`: When `stream` is enabled, this option allows receiving prompt processing progress information before the text generation begins. The progress responses contain a `prompt_processing` field with details about the number of tokens processed and overall progress. This is useful for long prompts where users want to see evaluation progress instead of waiting silently. Default: `false` (only applies when `stream` is `true`).
+
 `stop`: Specify a JSON array of stopping strings.
 These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`
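
A request that opts in to progress streaming could therefore look like the sketch below; only `stream` and `include_prompt_progress` are the fields this change documents, the remaining values are purely illustrative:

    {
      "prompt": "Summarize the following document: ...",
      "n_predict": 128,
      "stream": true,
      "include_prompt_progress": true
    }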

tools/server/server.cpp

Lines changed: 87 additions & 6 deletions
@@ -109,9 +109,10 @@ static bool server_task_type_need_logits(server_task_type task_type) {
 }
 
 struct slot_params {
-    bool stream = true;
-    bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
-    bool return_tokens = false;
+    bool stream                  = true;
+    bool cache_prompt            = true; // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens           = false;
+    bool include_prompt_progress = false; // include prompt processing progress in streaming responses
 
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -258,9 +259,10 @@ struct server_task {
         params.verbose = params_base.verbosity > 9;
         params.timings_per_token = json_value(data, "timings_per_token", false);
 
-        params.stream = json_value(data, "stream", false);
-        params.cache_prompt = json_value(data, "cache_prompt", true);
-        params.return_tokens = json_value(data, "return_tokens", false);
+        params.stream                  = json_value(data, "stream", false);
+        params.cache_prompt            = json_value(data, "cache_prompt", true);
+        params.return_tokens           = json_value(data, "return_tokens", false);
+        params.include_prompt_progress = json_value(data, "include_prompt_progress", false);
         params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent = json_value(data, "n_indent", defaults.n_indent);
         params.n_keep = json_value(data, "n_keep", defaults.n_keep);
@@ -898,6 +900,12 @@ struct server_task_result_cmpl_partial : server_task_result {
     completion_token_output prob_output;
     result_timings timings;
 
+    // Progress fields (only populated when is_progress_response is true)
+    bool is_progress_response = false;
+    int32_t n_past = 0;
+    int32_t n_prompt_tokens_processed = 0;
+    float progress = 0.0f;
+
     // OAI-compat fields
     bool verbose = false;
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -944,6 +952,15 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (!prob_output.probs.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
+        // include prompt processing progress if this is a progress response
+        if (is_progress_response) {
+            res["prompt_processing"] = json {
+                {"n_past", n_past},
+                {"n_prompt_tokens", n_prompt_tokens},
+                {"n_prompt_tokens_processed", n_prompt_tokens_processed},
+                {"progress", progress},
+            };
+        }
         return res;
     }
 
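
For illustration, a streamed progress chunk produced by this `to_json()` path might look roughly like the sketch below; the numeric values (and any additional fields the server includes in partial results) are illustrative, not taken from this commit:

    {
      "index": 0,
      "content": "",
      "tokens": [],
      "prompt_processing": {
        "n_past": 256,
        "n_prompt_tokens": 1024,
        "n_prompt_tokens_processed": 256,
        "progress": 0.25
      }
    }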

@@ -2515,6 +2532,64 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
+    void send_progress_response(server_slot & slot) {
+        // Only send progress if explicitly requested and streaming is enabled
+        if (!slot.params.include_prompt_progress || !slot.params.stream) {
+            return;
+        }
+
+        // Calculate current progress (fraction of prompt tokens processed)
+        float current_progress = slot.n_prompt_tokens > 0 ?
+            (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens : 0.0f;
+
+        // Send progress updates at regular intervals (roughly every 1% of the prompt)
+        static float last_progress = -1.0f;
+        static int last_slot_id = -1;
+
+        // Reset for a new task
+        if (slot.id_task != last_slot_id) {
+            last_progress = -1.0f;
+            last_slot_id = slot.id_task;
+        }
+
+        // Send progress if:
+        // 1. This is the first progress update (last_progress == -1)
+        // 2. Progress increased by at least 1%
+        // 3. We've completed processing (current_progress >= 1.0)
+        bool should_send = (last_progress < 0.0f) ||
+                           (current_progress - last_progress >= 0.01f) ||
+                           (current_progress >= 1.0f && last_progress < 1.0f);
+
+        if (!should_send) {
+            return;
+        }
+
+        last_progress = current_progress;
+
+        auto res = std::make_unique<server_task_result_cmpl_partial>();
+
+        res->id = slot.id_task;
+        res->index = slot.index;
+        res->content = ""; // empty content for progress responses
+        res->tokens = {};  // empty tokens for progress responses
+
+        res->n_decoded = 0; // no tokens decoded yet during prompt processing
+        res->n_prompt_tokens = slot.n_prompt_tokens;
+
+        // Progress-specific fields
+        res->is_progress_response = true;
+        res->n_past = slot.n_past;
+        res->n_prompt_tokens_processed = slot.n_prompt_tokens_processed;
+        res->progress = current_progress;
+
+        res->verbose = slot.params.verbose;
+        res->oaicompat = slot.params.oaicompat;
+        res->oaicompat_model = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+
+        queue_results.send(std::move(res));
+    }
+
     void send_final_response(server_slot & slot) {
         auto res = std::make_unique<server_task_result_cmpl_final>();
         res->id = slot.id_task;
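
As a rough consequence of the 0.01 threshold above, a long prompt yields on the order of 100 progress chunks at most: for a 10,000-token prompt, an update fires about every 100 processed tokens, plus the first update and a final one once `current_progress` reaches 1.0 (this sketch of the behavior assumes a single active task, since `last_progress` and `last_slot_id` are function-local statics shared across slots).
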
@@ -3334,12 +3409,18 @@ struct server_context {
 
                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
+
+                        // Send incremental progress updates during token processing
+                        send_progress_response(slot);
                     }
 
                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
                     SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
+                    // Send progress response if requested
+                    send_progress_response(slot);
+
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
