Commit de1880b

Author: anyshu
Working on v1/chat/completions

1 parent 8287c64 commit de1880b

File tree

1 file changed: +114 additions, -27 deletions


tools/server/server-diffusion.cpp

Lines changed: 114 additions & 27 deletions
@@ -722,6 +722,8 @@ struct callback_data {
     int32_t n_input;
     void * slot;       // For streaming updates (server_slot*)
     void * ctx_server; // For sending partial responses (server_context*)
+    std::string last_sent_text; // Track last sent text for delta calculation
+    llama_token * last_tokens;  // Track last tokens for partial text decoding
 };

 // Forward declarations for diffusion functions
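Note: the two new fields drive incremental streaming. `last_sent_text` lets the step callback compute what still needs to be emitted, and `last_tokens` lets it detect which token positions a diffusion step rewrote. As a rough standalone sketch (not the commit's code), delta computation against a text that may be rewritten rather than appended could look like this:

```cpp
#include <iostream>
#include <string>

// Hypothetical illustration: compute the streamable delta between the
// previously sent text and the current decoded text. Diffusion decoding can
// rewrite earlier tokens, so the previous text is not always a prefix of the
// current one; in that case the safe fallback is to resend everything.
static std::string compute_delta(const std::string & last_sent, const std::string & current) {
    if (current.compare(0, last_sent.size(), last_sent) == 0) {
        return current.substr(last_sent.size()); // append-only case: send the suffix
    }
    return current; // text was rewritten: resend everything
}

int main() {
    std::cout << compute_delta("Hello, wor", "Hello, world!") << "\n"; // prints "ld!"
}
```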
@@ -1294,8 +1296,15 @@ struct server_task_result_cmpl_partial : server_task_result {
         });
     }

-    for (const auto & diff : oaicompat_msg_diffs) {
-        add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
+    // For diffusion tasks, oaicompat_msg_diffs will be empty;
+    // in that case, use the content field directly.
+    if (!oaicompat_msg_diffs.empty()) {
+        for (const auto & diff : oaicompat_msg_diffs) {
+            add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
+        }
+    } else if (!content.empty() && !is_progress) {
+        // For diffusion or other tasks without diffs, send content directly
+        add_delta({{"content", content}});
     }

     if (!deltas.empty()) {
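For reference, the fallback branch means a diffusion update still surfaces as an ordinary OpenAI-style streaming chunk. A minimal sketch of the shape it produces, using nlohmann::json as the server does (field set trimmed; `id`, `model`, and timestamps omitted):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // The delta built by add_delta({{"content", content}}) in the branch above
    json delta = {{"content", "partial diffusion text"}};

    // Trimmed chat.completion.chunk wrapper around that delta
    json chunk = {
        {"object",  "chat.completion.chunk"},
        {"choices", json::array({ json{{"index", 0}, {"delta", delta}, {"finish_reason", nullptr}} })},
    };
    std::cout << "data: " << chunk.dump() << "\n\n"; // SSE framing
}
```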
@@ -2936,7 +2945,11 @@ struct server_context {
             res->content = tkn.text_to_send;
             res->tokens  = { tkn.tok };

-            slot.update_chat_msg(res->oaicompat_msg_diffs);
+            // For diffusion tasks, skip chat message diff computation because tokens are replaced, not appended.
+            // This avoids the "Invalid diff" error when text doesn't grow monotonically.
+            if (slot.task_type != SERVER_TASK_TYPE_DIFFUSION) {
+                slot.update_chat_msg(res->oaicompat_msg_diffs);
+            }
         }

         res->n_decoded = slot.n_decoded;
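The reasoning in the comment can be made concrete: `update_chat_msg` computes diffs under the assumption that the accumulated text only ever grows at the end, and a diffusion step violates that. A tiny illustration (hypothetical strings, not the server's code):

```cpp
#include <cassert>
#include <string>

// Illustration of why the diff computation is skipped for diffusion: an
// autoregressive stream only appends, while a diffusion step may replace a
// token in the middle, so the old text is no longer a prefix of the new text
// and a prefix-based diff cannot be formed.
int main() {
    std::string autoregressive_old = "The cat",        autoregressive_new = "The cat sat";
    std::string diffusion_old      = "The [MASK] sat", diffusion_new      = "The cat sat";
    assert(autoregressive_new.rfind(autoregressive_old, 0) == 0); // prefix holds
    assert(diffusion_new.rfind(diffusion_old, 0) != 0);           // prefix broken
}
```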
@@ -4036,9 +4049,18 @@ struct server_context {
         diff_params.seed          = slot.params.sampling.seed;
         diff_params.mask_token_id = llama_vocab_mask(vocab);

-        // Set up callback
+        // Set up callback with allocated last_tokens buffer
         const auto & prompt_text_tokens = slot.prompt_tokens.get_text_tokens();
-        callback_data cb_data = { &diff_params, vocab, (int32_t)prompt_text_tokens.size(), &slot, this };
+        std::vector<llama_token> last_tokens_buffer(diff_params.max_length);
+        callback_data cb_data = {
+            &diff_params,
+            vocab,
+            (int32_t)prompt_text_tokens.size(),
+            &slot,
+            this,
+            "",
+            last_tokens_buffer.data()
+        };
         diff_params.step_callback           = diffusion_step_callback;
         diff_params.step_callback_user_data = &cb_data;

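One thing worth noting about this hunk: `cb_data` receives a raw pointer into `last_tokens_buffer`, so the vector has to outlive every invocation of the step callback, which holds here because both live in the same scope as the generation call. A trimmed stand-in (hypothetical field subset, `llama_token` aliased locally) showing the positional aggregate initialization:

```cpp
#include <cstdint>
#include <string>
#include <vector>

using llama_token = int32_t; // stand-in for the llama.cpp typedef

// Trimmed stand-in for the server's callback_data; field order must match the
// positional aggregate initializer, since designated initializers are not used.
struct callback_data_sketch {
    int32_t       n_input;
    std::string   last_sent_text;
    llama_token * last_tokens; // non-owning: points into a caller-owned buffer
};

int main() {
    // The buffer must outlive every step-callback invocation; here it lives in
    // the same scope that drives generation, mirroring the commit.
    std::vector<llama_token> last_tokens_buffer(1024, 0); // value-initialized to 0
    callback_data_sketch cb_data = { 0, "", last_tokens_buffer.data() };
    (void)cb_data;
}
```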
@@ -4078,17 +4100,31 @@ struct server_context {
         slot.generated_text   = output_text;
         slot.generated_tokens = filtered_tokens;

-        // For streaming mode, send the complete text as a single chunk before the final response
-        if (slot.params.stream && !output_text.empty()) {
+        slot.n_decoded      = filtered_tokens.size();
+        slot.has_next_token = false;
+        slot.stop           = STOP_TYPE_LIMIT;
+
+        // In streaming mode the callback already sent the text incrementally,
+        // so only flush whatever remainder the callback did not send.
+        // In non-streaming mode, send the complete text as a single chunk
+        // before the final response.
+        if (slot.params.stream) {
+            // Check if we need to send any remaining text that wasn't sent by the callback
+            if (cb_data.last_sent_text != output_text && !output_text.empty()) {
+                std::string remaining_text = output_text.substr(cb_data.last_sent_text.length());
+                if (!remaining_text.empty()) {
+                    completion_token_output result;
+                    result.tok          = -1;
+                    result.text_to_send = remaining_text;
+                    result.prob         = 1.0f;
+                    send_partial_response(slot, result, false);
+                }
+            }
+        } else if (!output_text.empty()) {
+            // Non-streaming: send the complete text at once
             completion_token_output result;
-            result.tok = -1; // No specific token for diffusion output
+            result.tok          = -1;
             result.text_to_send = output_text;
             result.prob         = 1.0f;
-
-            slot.n_decoded = filtered_tokens.size();
-            slot.has_next_token = false;
-            slot.stop = STOP_TYPE_LIMIT;
-
             send_partial_response(slot, result, false);
         }

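A small caveat on the flush above: `std::string::substr(pos)` throws `std::out_of_range` when `pos` exceeds the string's size, which could be hit if `last_sent_text` ever ends up longer than the filtered `output_text`. A defensive variant (hypothetical helper, not in the commit):

```cpp
#include <iostream>
#include <string>

// Hypothetical defensive flush: substr(pos) throws std::out_of_range when
// pos > size(), which can happen if the callback sent more text than survives
// post-filtering. A length check avoids that edge case.
static std::string remaining_after(const std::string & sent, const std::string & full) {
    if (sent.size() >= full.size()) {
        return ""; // nothing left to flush (or text shrank after filtering)
    }
    return full.substr(sent.size());
}

int main() {
    // "Hello" was already sent; only ", world" remains to flush.
    std::cout << remaining_after("Hello", "Hello, world") << "\n";
}
```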
@@ -4766,22 +4802,64 @@ static bool diffusion_step_callback(int32_t step,
             (slot && slot->params.stream) ? 1 : 0,
             current_text.length());

-    if (slot && ctx_server && slot->params.stream) {
-        // Send progress update more frequently to show diffusion process
-        const int update_interval = std::max(1, total_steps / 50); // Send ~50 updates
-        if (step % update_interval == 0 || step == total_steps - 1) {
-            SRV_INF("Sending diffusion update: step=%d, interval=%d, first_100_chars='%.100s...'\n",
-                    step, update_interval, current_text.c_str());
-
-            completion_token_output progress_token;
-            progress_token.tok = -1; // Special value for progress
-            progress_token.text_to_send = current_text;
-            progress_token.prob = 1.0f;
+    if (slot && ctx_server && slot->params.stream) {
+        // Always send on the first and the last step
+        bool should_send = (step == 0) ||
+                           (step == total_steps - 1);
+
+        // Also send if the text has changed significantly (more tokens decoded)
+        if (!should_send && current_text.length() > data->last_sent_text.length() + 10) {
+            should_send = true;
+        }
+
+        // For chat/completions
+        if (should_send) {
+            std::string delta_text;
+            // Track token changes for debugging
+            if (data->last_tokens && step > 0) {
+                int32_t changed_tokens = 0;
+                std::string changes_debug;
+                for (int32_t i = data->n_input; i < n_tokens && i < data->diff_params->max_length; i++) {
+                    if (data->last_tokens[i] != tokens[i]) {
+                        changed_tokens++;
+                        if (changes_debug.length() < 200) { // Limit debug output
+                            char old_piece[64], new_piece[64];
+                            int old_n_chars = llama_token_to_piece(data->vocab, data->last_tokens[i], old_piece, sizeof(old_piece), 0, false);
+                            int new_n_chars = llama_token_to_piece(data->vocab, tokens[i], new_piece, sizeof(new_piece), 0, false);
+                            old_piece[old_n_chars] = '\0';
+                            new_piece[new_n_chars] = '\0';
+                            changes_debug += string_format("[%d: '%s'->'%s'] ", i - data->n_input, old_piece, new_piece);
+                        }
+                    }
+                }
+                if (changed_tokens > 0) {
+                    delta_text = string_format("Token changes at step %d: %d tokens changed - %s\n",
+                                               step, changed_tokens, changes_debug.c_str());
+                    SRV_INF("%s", delta_text.c_str());
+                }
+            }

-            // Use is_progress=false to send actual content instead of progress info
-            ctx_server->send_partial_response(*slot, progress_token, false);
+            if (!delta_text.empty()) {
+                SRV_INF("Sending diffusion delta: step=%d/%d, delta_len=%zu, delta=%s\n",
+                        step, total_steps, delta_text.length(), delta_text.c_str());
+
+                completion_token_output progress_token;
+                progress_token.tok          = -1; // Special value for progress
+                progress_token.text_to_send = delta_text;
+                progress_token.prob         = 1.0f;
+
+                // Use is_progress=false to send actual content instead of progress info
+                ctx_server->send_partial_response(*slot, progress_token, false);
+
+                // Update last sent text
+                data->last_sent_text = current_text;
+            }
         }
     }
+    // Save current tokens for next comparison
+    if (data->last_tokens) {
+        std::memcpy(data->last_tokens, tokens, n_tokens * sizeof(llama_token));
+    }

     return true;
 }
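The send-throttling rule in this callback is easy to state on its own: emit on the first step, on the last step, or once the decoded text has grown by more than 10 characters since the last update actually sent. A standalone restatement (hypothetical free function, threshold copied from the hunk):

```cpp
#include <iostream>
#include <string>

// Restatement of the commit's throttling rule: send on the first and last
// step, or whenever the decoded text has grown by more than 10 characters
// since the last update that was actually sent.
static bool should_send_update(int step, int total_steps,
                               const std::string & current_text,
                               const std::string & last_sent_text) {
    if (step == 0 || step == total_steps - 1) {
        return true;
    }
    return current_text.length() > last_sent_text.length() + 10;
}

int main() {
    // 12 new characters against an empty last-sent text exceeds the threshold.
    std::cout << should_send_update(5, 64, "0123456789AB", "") << "\n"; // prints 1
}
```

Separately, `llama_token_to_piece` returns a negative value when the destination buffer is too small, and the debug path above indexes `old_piece`/`new_piece` with the raw return value, so clamping it to the valid range before writing the terminator would be a safe hardening.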
@@ -5571,8 +5649,17 @@ int main(int argc, char ** argv) {
             ctx_server.oai_parser_opt,
             files);

+        // Check if this is a diffusion request by looking for diffusion-specific parameters
+        bool is_diffusion = data.contains("diffusion_steps") ||
+                            data.contains("diffusion_algorithm") ||
+                            data.contains("cfg_scale") ||
+                            data.contains("visual_mode") ||
+                            data.contains("max_length");
+
+        server_task_type task_type = is_diffusion ? SERVER_TASK_TYPE_DIFFUSION : SERVER_TASK_TYPE_COMPLETION;
+
         handle_completions_impl(
-            SERVER_TASK_TYPE_COMPLETION,
+            task_type,
             data,
             files,
             req.is_connection_closed,
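The detection heuristic keys purely on parameter presence, and a standalone restatement makes its main trade-off visible: any request carrying one of these keys (including a generic-sounding one like "max_length") is routed to the diffusion path. The sketch below uses nlohmann::json and a hypothetical model name:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Restatement of the commit's routing heuristic. The probe keys are presence
// checks only: a client sending "max_length" for an ordinary completion would
// also be routed to the diffusion path.
static bool is_diffusion_request(const json & data) {
    return data.contains("diffusion_steps")     ||
           data.contains("diffusion_algorithm") ||
           data.contains("cfg_scale")           ||
           data.contains("visual_mode")         ||
           data.contains("max_length");
}

int main() {
    json req = {
        {"model", "dream-7b"}, // hypothetical diffusion-LM model name
        {"messages", json::array()},
        {"diffusion_steps", 64},
    };
    std::cout << (is_diffusion_request(req) ? "diffusion" : "completion") << "\n";
}
```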
