 #include "llama.h"
 #include "log.h"
 
-
 #include <limits.h>
-#include <iostream>
 
 #include <algorithm>
 #include <cmath>
@@ -132,22 +130,27 @@ static bool diffusion_step_callback(int32_t step,
 
     callback_data * data = static_cast<callback_data *>(user_data);
 
-    // Clear screen and redraw everything each time to avoid cursor position issues
-    printf("\r\033[2J\033[H"); // Clear screen and move cursor to top-left
-
-    // Print progress bar
-    int progress_percent = (step * 100) / total_steps;
-    int progress_bars = (step * 50) / total_steps;
-    printf("diffusion step: %d/%d [%s%s] %d%%\n",
-           step,
-           total_steps,
-           std::string(progress_bars, '=').c_str(),
-           std::string(50 - progress_bars, ' ').c_str(),
-           progress_percent);
-
-    // Print current content if in visual mode
+    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
+        int progress_percent = (step * 100) / total_steps;
+        int progress_bars = (step * 50) / total_steps;
+        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
+                step,
+                total_steps,
+                std::string(progress_bars, '=').c_str(),
+                std::string(50 - progress_bars, ' ').c_str(),
+                progress_percent);
+    };
+
     if (data->diff_params->visual_mode) {
-        std::string current_text = "";
+        // Visual mode: clear
+        LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
+
+        print_progress_bar(step, total_steps);
+
+        LOG_INF("\n");
+
+        std::string current_text = "";
+
         for (int32_t i = data->n_input; i < n_tokens; i++) {
             std::string token_str;
             if (tokens[i] != llama_vocab_mask(data->vocab)) {
@@ -158,14 +161,16 @@ static bool diffusion_step_callback(int32_t step,
                     token_str = piece;
                 }
             } else {
-                token_str = "_ "; // Use underscore for mask tokens to show progress
+                token_str = " ";
             }
+
             current_text += token_str;
         }
-        printf("%s\n", current_text.c_str());
+
+        LOG_INF("%s\n", current_text.c_str());
+    } else {
+        print_progress_bar(step, total_steps);
     }
-
-    fflush(stdout);
 
     return true;
 }
@@ -493,8 +498,7 @@ static void diffusion_generate(llama_context * ctx,
     int64_t time_end = ggml_time_us();
     total_time += time_end - time_start;
 
-    // Print final timing info
-    LOG_INF("total time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
+    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
             total_time / 1000.0,
             total_time / 1000.0 / params.steps,
             total_sampling_time / 1000.0 / params.steps);
@@ -584,25 +588,51 @@ int main(int argc, char ** argv) {
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
+    std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
+
+    std::vector<llama_token> input_tokens = common_tokenize(vocab,
+                                                            formatted_prompt,
+                                                            /*add special tokens*/ true,
+                                                            /*parse special*/ true);
+
+    int n_input = input_tokens.size();
+
+    if (n_input >= params.n_ctx) {
+        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
+        llama_free(ctx);
+        llama_model_free(model);
+        return 1;
+    }
+
     llama_token mask_token_id = llama_vocab_mask(vocab);
+
     GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
+
     bool visual_mode = params.diffusion.visual_mode;
+
+    int32_t n_generated = 0;
+    std::vector<llama_token> output_tokens(params.n_ubatch);
+
     struct diffusion_params diff_params;
+
     char shift_logits_str[8];
     if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
         diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
     } else {
         diff_params.shift_logits = true;
     }
+
     // Use either eps or block length, but not both
     GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
+
     if (params.diffusion.eps) {
         diff_params.schedule = TIMESTEP_BASED;
         diff_params.eps = params.diffusion.eps;
     } else if (params.diffusion.block_length) {
         diff_params.schedule = BLOCK_BASED;
         diff_params.block_length = params.diffusion.block_length;
     }
+
     diff_params.mask_token_id = mask_token_id;
     diff_params.seed = params.sampling.seed;
     diff_params.temperature = params.sampling.temp;
@@ -613,64 +643,51 @@ int main(int argc, char ** argv) {
     diff_params.top_k = params.sampling.top_k;
     diff_params.visual_mode = params.diffusion.visual_mode;
     diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;
-    diff_params.step_callback = diffusion_step_callback;
-    // interactive mode
-    if (params.prompt.empty()) {
-        LOG_INF("Entering interactive mode; type exit or quit to leave.\n");
-        std::string user_prompt;
-        while (true) {
-            printf("\nPlease enter a question: ");
-            std::getline(std::cin, user_prompt);
-            if (user_prompt.empty()) continue;
-            if (user_prompt == "exit" || user_prompt == "quit") break;
-            std::string formatted_prompt = format_input_text(user_prompt, params.system_prompt, params.enable_chat_template, model);
-            std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt, true, true);
-            int n_input = input_tokens.size();
-            if (n_input >= params.n_ctx) {
-                LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
-                continue;
-            }
-            std::vector<llama_token> output_tokens(params.n_ubatch);
-            callback_data cb_data = { &diff_params, vocab, n_input };
-            diff_params.step_callback_user_data = &cb_data;
-            int n_generated = 0;
-            diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
-            output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
-            std::string output_data = common_detokenize(vocab, output_tokens, false);
-            // LOG_INF("\n%s\n", output_data.c_str());
-        }
-        llama_free(ctx);
-        llama_model_free(model);
-        llama_backend_free();
-        return 0;
+
+    diff_params.step_callback = diffusion_step_callback;
+    callback_data cb_data = { &diff_params, vocab, n_input };
+    diff_params.step_callback_user_data = &cb_data;
+
+    const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
+    const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
+    const char * alg_name =
+        (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
+    const char * sched_name =
+        (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";
+
+    LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
+    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
+    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
+    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
+    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
+    LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
+    if (diff_params.schedule == TIMESTEP_BASED) {
+        LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
+        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
     }
-    // non-interactive mode, original logic
-    std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
-    std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt, true, true);
-    int n_input = input_tokens.size();
-    if (n_input >= params.n_ctx) {
-        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
-        llama_free(ctx);
-        llama_model_free(model);
-        return 1;
+    if (diff_params.schedule == BLOCK_BASED) {
+        LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
+        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
     }
-    std::vector<llama_token> output_tokens(params.n_ubatch);
-    callback_data cb_data = { &diff_params, vocab, n_input };
-    diff_params.step_callback_user_data = &cb_data;
-    int n_generated = 0;
+
     diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
+
     if (n_generated > 0) {
         if (visual_mode) {
+            // clear screen and move cursor to top-left
             LOG_INF("\033[2J\033[H");
         }
+
         output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
         std::string output_data = common_detokenize(vocab, output_tokens, false);
         LOG_INF("\n%s\n", output_data.c_str());
     } else {
         LOG_INF("Error: diffusion generation failed\n");
     }
+
     llama_free(ctx);
     llama_model_free(model);
     llama_backend_free();
+
     return 0;
-}
+}