Skip to content

Commit 895d7a6

Browse files
author
anyshu
committed
/v1/completions works for now (temporary)
1 parent f7dfab1 commit 895d7a6

File tree

1 file changed

+82
-28
lines changed

1 file changed

+82
-28
lines changed

tools/server/server-diffusion.cpp

Lines changed: 82 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,8 @@ struct callback_data {
720720
diffusion_params * diff_params;
721721
const llama_vocab * vocab;
722722
int32_t n_input;
723+
void * slot; // For streaming updates (server_slot*)
724+
void * ctx_server; // For sending partial responses (server_context*)
723725
};
724726

725727
// Forward declarations for diffusion functions
@@ -910,6 +912,7 @@ struct swa_checkpoint {
910912
std::vector<uint8_t> data;
911913
};
912914

915+
//last output chunk, sent when the generation is finished
913916
struct server_task_result_cmpl_final : server_task_result {
914917
int index = 0;
915918

@@ -4035,7 +4038,7 @@ struct server_context {
40354038

40364039
// Set up callback
40374040
const auto & prompt_text_tokens = slot.prompt_tokens.get_text_tokens();
4038-
callback_data cb_data = { &diff_params, vocab, (int32_t)prompt_text_tokens.size() };
4041+
callback_data cb_data = { &diff_params, vocab, (int32_t)prompt_text_tokens.size(), &slot, this };
40394042
diff_params.step_callback = diffusion_step_callback;
40404043
diff_params.step_callback_user_data = &cb_data;
40414044

@@ -4075,6 +4078,20 @@ struct server_context {
40754078
slot.generated_text = output_text;
40764079
slot.generated_tokens = filtered_tokens;
40774080

4081+
// For streaming mode, send the complete text as a single chunk before the final response
4082+
if (slot.params.stream && !output_text.empty()) {
4083+
completion_token_output result;
4084+
result.tok = -1; // No specific token for diffusion output
4085+
result.text_to_send = output_text;
4086+
result.prob = 1.0f;
4087+
4088+
slot.n_decoded = filtered_tokens.size();
4089+
slot.has_next_token = false;
4090+
slot.stop = STOP_TYPE_LIMIT;
4091+
4092+
send_partial_response(slot, result, false);
4093+
}
4094+
40784095
send_final_response(slot);
40794096
} else {
40804097
send_error(slot, "Diffusion generation failed");
@@ -4683,10 +4700,31 @@ static bool diffusion_step_callback(int32_t step,
46834700
const llama_token * tokens,
46844701
int32_t n_tokens,
46854702
void * user_data) {
4686-
(void) user_data;
4687-
46884703
callback_data * data = static_cast<callback_data *>(user_data);
46894704

4705+
SRV_INF("%s", "\033[2J\033[H");
4706+
SRV_INF("diffusion_step_callback ENTRY: step=%d/%d, n_tokens=%d, n_input=%d, expected_range=%d\n",
4707+
step, total_steps, n_tokens, data->n_input, data->diff_params->max_length);
4708+
4709+
// Debug: Print first few tokens to see what's in the array
4710+
std::string current_text = "Token array sample (first 20 after input): ";
4711+
for (int32_t i = data->n_input; i < n_tokens; i++) {
4712+
std::string token_str;
4713+
if (tokens[i] != llama_vocab_mask(data->vocab)) {
4714+
char piece[256];
4715+
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
4716+
if (n_chars > 0) {
4717+
piece[n_chars] = '\0';
4718+
token_str = piece;
4719+
}
4720+
} else {
4721+
token_str = "_"; // Represent mask token as "_"
4722+
}
4723+
4724+
current_text += token_str;
4725+
}
4726+
SRV_INF("%s\n", current_text.c_str());
4727+
46904728
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
46914729
int progress_percent = (step * 100) / total_steps;
46924730
int progress_bars = (step * 50) / total_steps;
@@ -4697,38 +4735,54 @@ static bool diffusion_step_callback(int32_t step,
46974735
std::string(50 - progress_bars, ' ').c_str(),
46984736
progress_percent);
46994737
};
4738+
4739+
// Count mask and real tokens for debug logging
4740+
llama_token mask_token = llama_vocab_mask(data->vocab);
4741+
int32_t mask_count = 0;
4742+
int32_t real_token_count = 0;
4743+
for (int32_t i = data->n_input; i < n_tokens; i++) {
4744+
if (tokens[i] == mask_token) {
4745+
mask_count++;
4746+
} else {
4747+
real_token_count++;
4748+
}
4749+
}
47004750

47014751
if (data->diff_params->visual_mode) {
4702-
// Visual mode: clear
4703-
SRV_INF("%s", "\033[2J\033[H"); // Clear screen and move cursor to top-left
4704-
4752+
// Visual mode: clear screen
4753+
// SRV_INF("%s", "\033[2J\033[H");
47054754
print_progress_bar(step, total_steps);
4706-
4707-
SRV_INF("%s", "\n");
4708-
4709-
std::string current_text = " ";
4710-
4711-
for (int32_t i = data->n_input; i < n_tokens; i++) {
4712-
std::string token_str;
4713-
if (tokens[i] != llama_vocab_mask(data->vocab)) {
4714-
char piece[256];
4715-
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
4716-
if (n_chars > 0) {
4717-
piece[n_chars] = '\0';
4718-
token_str = piece;
4719-
}
4720-
} else {
4721-
token_str = " ";
4722-
}
4723-
4724-
current_text += token_str;
4725-
}
4726-
4727-
SRV_INF("%s\n", current_text.c_str());
4755+
SRV_INF("\n%s\n", current_text.c_str());
47284756
} else {
47294757
print_progress_bar(step, total_steps);
47304758
}
47314759

4760+
// Send streaming update if slot is available and streaming is enabled
4761+
server_slot * slot = static_cast<server_slot*>(data->slot);
4762+
server_context * ctx_server = static_cast<server_context*>(data->ctx_server);
4763+
4764+
SRV_DBG("diffusion_step_callback: step=%d/%d, mask=%d, real=%d, stream=%d, text_len=%zu\n",
4765+
step, total_steps, mask_count, real_token_count,
4766+
(slot && slot->params.stream) ? 1 : 0,
4767+
current_text.length());
4768+
4769+
if (slot && ctx_server && slot->params.stream) {
4770+
// Send progress update more frequently to show diffusion process
4771+
const int update_interval = std::max(1, total_steps / 50); // Send ~50 updates
4772+
if (step % update_interval == 0 || step == total_steps - 1) {
4773+
SRV_INF("Sending diffusion update: step=%d, interval=%d, first_100_chars='%.100s...'\n",
4774+
step, update_interval, current_text.c_str());
4775+
4776+
completion_token_output progress_token;
4777+
progress_token.tok = -1; // Special value for progress
4778+
progress_token.text_to_send = current_text;
4779+
progress_token.prob = 1.0f;
4780+
4781+
// Use is_progress=false to send actual content instead of progress info
4782+
ctx_server->send_partial_response(*slot, progress_token, false);
4783+
}
4784+
}
4785+
47324786
return true;
47334787
}
47344788

0 commit comments

Comments (0)