@@ -720,6 +720,8 @@ struct callback_data {
720
720
diffusion_params * diff_params;
721
721
const llama_vocab * vocab;
722
722
int32_t n_input;
723
+ void * slot; // For streaming updates (server_slot*)
724
+ void * ctx_server; // For sending partial responses (server_context*)
723
725
};
724
726
725
727
// Forward declarations for diffusion functions
@@ -910,6 +912,7 @@ struct swa_checkpoint {
910
912
std::vector<uint8_t > data;
911
913
};
912
914
915
+ // last output chunk, sent when the generation is finished
913
916
struct server_task_result_cmpl_final : server_task_result {
914
917
int index = 0 ;
915
918
@@ -4035,7 +4038,7 @@ struct server_context {
4035
4038
4036
4039
// Set up callback
4037
4040
const auto & prompt_text_tokens = slot.prompt_tokens .get_text_tokens ();
4038
- callback_data cb_data = { &diff_params, vocab, (int32_t )prompt_text_tokens.size () };
4041
+ callback_data cb_data = { &diff_params, vocab, (int32_t )prompt_text_tokens.size (), &slot, this };
4039
4042
diff_params.step_callback = diffusion_step_callback;
4040
4043
diff_params.step_callback_user_data = &cb_data;
4041
4044
@@ -4075,6 +4078,20 @@ struct server_context {
4075
4078
slot.generated_text = output_text;
4076
4079
slot.generated_tokens = filtered_tokens;
4077
4080
4081
+ // For streaming mode, send the complete text as a single chunk before the final response
4082
+ if (slot.params .stream && !output_text.empty ()) {
4083
+ completion_token_output result;
4084
+ result.tok = -1 ; // No specific token for diffusion output
4085
+ result.text_to_send = output_text;
4086
+ result.prob = 1 .0f ;
4087
+
4088
+ slot.n_decoded = filtered_tokens.size ();
4089
+ slot.has_next_token = false ;
4090
+ slot.stop = STOP_TYPE_LIMIT;
4091
+
4092
+ send_partial_response (slot, result, false );
4093
+ }
4094
+
4078
4095
send_final_response (slot);
4079
4096
} else {
4080
4097
send_error (slot, " Diffusion generation failed" );
@@ -4683,10 +4700,31 @@ static bool diffusion_step_callback(int32_t step,
4683
4700
const llama_token * tokens,
4684
4701
int32_t n_tokens,
4685
4702
void * user_data) {
4686
- (void ) user_data;
4687
-
4688
4703
callback_data * data = static_cast <callback_data *>(user_data);
4689
4704
4705
+ SRV_INF (" %s" , " \033 [2J\033 [H" );
4706
+ SRV_INF (" diffusion_step_callback ENTRY: step=%d/%d, n_tokens=%d, n_input=%d, expected_range=%d\n " ,
4707
+ step, total_steps, n_tokens, data->n_input , data->diff_params ->max_length );
4708
+
4709
+ // Debug: Print first few tokens to see what's in the array
4710
+ std::string current_text = " Token array sample (first 20 after input): " ;
4711
+ for (int32_t i = data->n_input ; i < n_tokens; i++) {
4712
+ std::string token_str;
4713
+ if (tokens[i] != llama_vocab_mask (data->vocab )) {
4714
+ char piece[256 ];
4715
+ int n_chars = llama_token_to_piece (data->vocab , tokens[i], piece, sizeof (piece), 0 , false );
4716
+ if (n_chars > 0 ) {
4717
+ piece[n_chars] = ' \0 ' ;
4718
+ token_str = piece;
4719
+ }
4720
+ } else {
4721
+ token_str = " _" ; // Represent mask token as "_"
4722
+ }
4723
+
4724
+ current_text += token_str;
4725
+ }
4726
+ SRV_INF (" %s\n " , current_text.c_str ());
4727
+
4690
4728
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
4691
4729
int progress_percent = (step * 100 ) / total_steps;
4692
4730
int progress_bars = (step * 50 ) / total_steps;
@@ -4697,38 +4735,54 @@ static bool diffusion_step_callback(int32_t step,
4697
4735
std::string (50 - progress_bars, ' ' ).c_str (),
4698
4736
progress_percent);
4699
4737
};
4738
+
4739
+ // Count mask and real tokens for debug logging
4740
+ llama_token mask_token = llama_vocab_mask (data->vocab );
4741
+ int32_t mask_count = 0 ;
4742
+ int32_t real_token_count = 0 ;
4743
+ for (int32_t i = data->n_input ; i < n_tokens; i++) {
4744
+ if (tokens[i] == mask_token) {
4745
+ mask_count++;
4746
+ } else {
4747
+ real_token_count++;
4748
+ }
4749
+ }
4700
4750
4701
4751
if (data->diff_params ->visual_mode ) {
4702
- // Visual mode: clear
4703
- SRV_INF (" %s" , " \033 [2J\033 [H" ); // Clear screen and move cursor to top-left
4704
-
4752
+ // Visual mode: clear screen
4753
+ // SRV_INF("%s", "\033[2J\033[H");
4705
4754
print_progress_bar (step, total_steps);
4706
-
4707
- SRV_INF (" %s" , " \n " );
4708
-
4709
- std::string current_text = " " ;
4710
-
4711
- for (int32_t i = data->n_input ; i < n_tokens; i++) {
4712
- std::string token_str;
4713
- if (tokens[i] != llama_vocab_mask (data->vocab )) {
4714
- char piece[256 ];
4715
- int n_chars = llama_token_to_piece (data->vocab , tokens[i], piece, sizeof (piece), 0 , false );
4716
- if (n_chars > 0 ) {
4717
- piece[n_chars] = ' \0 ' ;
4718
- token_str = piece;
4719
- }
4720
- } else {
4721
- token_str = " " ;
4722
- }
4723
-
4724
- current_text += token_str;
4725
- }
4726
-
4727
- SRV_INF (" %s\n " , current_text.c_str ());
4755
+ SRV_INF (" \n %s\n " , current_text.c_str ());
4728
4756
} else {
4729
4757
print_progress_bar (step, total_steps);
4730
4758
}
4731
4759
4760
+ // Send streaming update if slot is available and streaming is enabled
4761
+ server_slot * slot = static_cast <server_slot*>(data->slot );
4762
+ server_context * ctx_server = static_cast <server_context*>(data->ctx_server );
4763
+
4764
+ SRV_DBG (" diffusion_step_callback: step=%d/%d, mask=%d, real=%d, stream=%d, text_len=%zu\n " ,
4765
+ step, total_steps, mask_count, real_token_count,
4766
+ (slot && slot->params .stream ) ? 1 : 0 ,
4767
+ current_text.length ());
4768
+
4769
+ if (slot && ctx_server && slot->params .stream ) {
4770
+ // Send progress update more frequently to show diffusion process
4771
+ const int update_interval = std::max (1 , total_steps / 50 ); // Send ~50 updates
4772
+ if (step % update_interval == 0 || step == total_steps - 1 ) {
4773
+ SRV_INF (" Sending diffusion update: step=%d, interval=%d, first_100_chars='%.100s...'\n " ,
4774
+ step, update_interval, current_text.c_str ());
4775
+
4776
+ completion_token_output progress_token;
4777
+ progress_token.tok = -1 ; // Special value for progress
4778
+ progress_token.text_to_send = current_text;
4779
+ progress_token.prob = 1 .0f ;
4780
+
4781
+ // Use is_progress=false to send actual content instead of progress info
4782
+ ctx_server->send_partial_response (*slot, progress_token, false );
4783
+ }
4784
+ }
4785
+
4732
4786
return true ;
4733
4787
}
4734
4788
0 commit comments