 #include "arg.h"
 #include "chat.h"
 #include "common.h"
+#include "diffusion.h"
 #include "llama.h"
 #include "log.h"
-#include "diffusion.h"
 
 static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
     if (!use_chat_template) {
@@ -36,7 +36,6 @@ struct callback_data {
     const common_params_diffusion * diff_params;
     const llama_vocab * vocab;
     int32_t n_input;
-    llama_token mask_token_id; // Store mask token separately since it's not in diffusion params
 };
 
 static bool diffusion_step_callback(int32_t step, int32_t total_steps, const llama_token * tokens, int32_t n_tokens,
@@ -46,8 +45,8 @@ static bool diffusion_step_callback(int32_t step, int32_t total_steps, const lla
     auto print_progress_bar = [](int32_t step, int32_t total_steps) {
         int progress_percent = (step * 100) / total_steps;
         int progress_bars    = (step * 50) / total_steps;
-        std::cerr << "diffusion step: " << step << "/" << total_steps << " [" << std::string(progress_bars, '=')
-                  << std::string(50 - progress_bars, ' ') << "] " << progress_percent << "%\n";
+        std::cerr << "\rdiffusion step: " << step << "/" << total_steps << " [" << std::string(progress_bars, '=')
+                  << std::string(50 - progress_bars, ' ') << "] " << progress_percent << "%";
     };
 
     if (data->diff_params->visual_mode) {
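Note: the hunk above switches the progress bar from printing one line per step to redrawing in place. A minimal standalone sketch of the `\r` idiom (not part of the patch) behaves like this:

```cpp
#include <iostream>
#include <string>

int main() {
    const int total_steps = 64;
    for (int step = 1; step <= total_steps; ++step) {
        int progress_percent = (step * 100) / total_steps;
        int progress_bars    = (step * 50) / total_steps;
        // the leading "\r" returns the cursor to column 0, so each step
        // overwrites the previous bar instead of emitting a new line
        std::cerr << "\rdiffusion step: " << step << "/" << total_steps << " ["
                  << std::string(progress_bars, '=') << std::string(50 - progress_bars, ' ')
                  << "] " << progress_percent << "%";
    }
    std::cerr << "\n"; // terminate the final bar with a newline
    return 0;
}
```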
@@ -56,11 +55,13 @@ static bool diffusion_step_callback(int32_t step, int32_t total_steps, const lla
 
         print_progress_bar(step, total_steps);
 
+        std::cerr << "\n";
+
         std::string current_text = "";
 
         for (int32_t i = data->n_input; i < n_tokens; i++) {
             std::string token_str;
-            if (tokens[i] != data->mask_token_id) {
+            if (tokens[i] != llama_vocab_mask(data->vocab)) {
                 char piece[256];
                 int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
                 if (n_chars > 0) {
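Note: with `mask_token_id` dropped from `callback_data`, the callback now asks the vocab for the mask token directly. A sketch of the visible-text reconstruction it performs (helper name and hoisted lookup are illustrative, not from the patch):

```cpp
// Rebuild the currently visible text: positions still holding the mask token
// are skipped; everything else is converted with llama_token_to_piece().
static std::string visible_text(const llama_vocab * vocab, const llama_token * tokens,
                                int32_t n_input, int32_t n_tokens) {
    const llama_token mask_id = llama_vocab_mask(vocab); // the patch queries this per token
    std::string text;
    for (int32_t i = n_input; i < n_tokens; i++) {
        if (tokens[i] == mask_id) {
            continue; // still masked at this diffusion step
        }
        char piece[256];
        int n_chars = llama_token_to_piece(vocab, tokens[i], piece, sizeof(piece), 0, false);
        if (n_chars > 0) {
            text.append(piece, n_chars);
        }
    }
    return text;
}
```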
@@ -135,9 +136,8 @@ int main(int argc, char ** argv) {
     std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
 
     std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
-                                                            true, // add_special tokens
-                                                            true  // parse_special
-    );
+                                                            /*add special tokens*/ true,
+                                                            /*parse special*/ true);
     int n_input = input_tokens.size();
 
     if (n_input >= params.n_ctx) {
@@ -148,14 +148,14 @@ int main(int argc, char ** argv) {
     }
 
     struct diffusion_params ldiff_params = diffusion_default_params();
-    ldiff_params.steps = params.diffusion.steps;
-    ldiff_params.eps = params.diffusion.eps;
-    ldiff_params.temperature = params.sampling.temp;
-    ldiff_params.top_p = params.sampling.top_p;
-    ldiff_params.top_k = params.sampling.top_k;
-    ldiff_params.algorithm = static_cast<enum diffusion_algorithm>(params.diffusion.algorithm);
-    ldiff_params.alg_temp = params.diffusion.alg_temp;
-    ldiff_params.seed = params.sampling.seed;
+    ldiff_params.steps       = params.diffusion.steps;
+    ldiff_params.eps         = params.diffusion.eps;
+    ldiff_params.temperature = params.sampling.temp;
+    ldiff_params.top_p       = params.sampling.top_p;
+    ldiff_params.top_k       = params.sampling.top_k;
+    ldiff_params.algorithm   = static_cast<enum diffusion_algorithm>(params.diffusion.algorithm);
+    ldiff_params.alg_temp    = params.diffusion.alg_temp;
+    ldiff_params.seed        = params.sampling.seed;
 
     llama_token mask_token_id = llama_vocab_mask(vocab);
     GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
@@ -169,38 +169,36 @@ int main(int argc, char ** argv) {
 
     ldiff_params.mask_token_id = mask_token_id;
 
-    callback_data cb_data = { &params.diffusion, vocab, n_input, mask_token_id };
+    callback_data cb_data = { &params.diffusion, vocab, n_input };
 
     ldiff_params.step_callback = diffusion_step_callback;
     ldiff_params.step_callback_user_data = &cb_data;
 
-    int32_t n_generated = 0;
+    int32_t n_generated = 0;
 
-    int64_t t1 = ggml_time_us();
-    llama_token * generated = diffusion_generate(ctx, input_tokens.data(), n_input, params.diffusion.max_length,
-                                                 ldiff_params, &n_generated);
+    int64_t                  t1 = ggml_time_us();
+    std::vector<llama_token> output_tokens(params.diffusion.max_length);
+    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.diffusion.max_length,
+                       ldiff_params, &n_generated);
     int64_t t2 = ggml_time_us();
-    if (params.diffusion.visual_mode) {
-        std::cerr << "\033[2J\033[H"; // Clear screen and move cursor to top-left
-    } else {
-        std::cerr << "\r" << std::string(80, ' ') << "\r" << std::flush;
-    }
-
-    if (generated && n_generated > 0) {
-        std::vector<llama_token> output_tokens(generated + n_input, generated + n_generated);
 
+    if (n_generated > 0) {
+        if (params.diffusion.visual_mode) {
+            // clear screen and move cursor to top-left
+            std::cerr << "\033[2J\033[H";
+        }
+        output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
         std::string output_data = common_detokenize(vocab, output_tokens, false);
-        std::cout << output_data << std::endl;
-
-        delete[] generated;
+        std::cout << "\n" << output_data << "\n";
     } else {
-        std::cerr << "Error: diffusion generation failed" << std::endl;
+        std::cerr << "Error: diffusion generation failed\n";
         llama_free(ctx);
         llama_model_free(model);
         return 1;
     }
 
-    std::cerr << "diffusion time: " << (t2 - t1)/1000.0 << " ms time per step: " << (t2 - t1)/1000.0/params.diffusion.steps << " ms" << std::endl;
+    std::cerr << "diffusion time: " << (t2 - t1) / 1000.0
+              << " ms time per step: " << (t2 - t1) / 1000.0 / params.diffusion.steps << " ms" << std::endl;
 
     llama_free(ctx);
     llama_model_free(model);
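Note: the core change in the last hunk is the `diffusion_generate` contract: instead of returning a heap array the caller must `delete[]`, it now fills a caller-owned buffer. A minimal calling sketch under the signature implied by the diff (surrounding names `ctx`, `input_tokens`, `ldiff_params`, `params`, and `vocab` assumed in scope):

```cpp
// Caller allocates max_length tokens up front; on return, n_generated says
// how many entries are valid, and the first n_input of them echo the prompt.
std::vector<llama_token> output_tokens(params.diffusion.max_length);
int32_t n_generated = 0;

diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input,
                   params.diffusion.max_length, ldiff_params, &n_generated);

if (n_generated > 0) {
    // drop the echoed prompt; what remains is the generated continuation
    output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
    std::cout << common_detokenize(vocab, output_tokens, false) << "\n";
}
// no delete[] needed: the vector owns the storage
```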