Commit 5bf7496

Author: anyshu (committed)
Commit message: 恢复到原来的9/16 (revert to the original 9/16 version)

1 parent 3f76940, commit 5bf7496

File tree: 1 file changed, +84 -67 lines changed


examples/diffusion/diffusion-cli.cpp

Lines changed: 84 additions & 67 deletions
@@ -4,9 +4,7 @@
 #include "llama.h"
 #include "log.h"
 
-
 #include <limits.h>
-#include <iostream>
 
 #include <algorithm>
 #include <cmath>
@@ -132,22 +130,27 @@ static bool diffusion_step_callback(int32_t step,
 
     callback_data * data = static_cast<callback_data *>(user_data);
 
-    // Clear screen and redraw everything each time to avoid cursor position issues
-    printf("\r\033[2J\033[H"); // Clear screen and move cursor to top-left
-
-    // Print progress bar
-    int progress_percent = (step * 100) / total_steps;
-    int progress_bars = (step * 50) / total_steps;
-    printf("diffusion step: %d/%d [%s%s] %d%%\n",
-           step,
-           total_steps,
-           std::string(progress_bars, '=').c_str(),
-           std::string(50 - progress_bars, ' ').c_str(),
-           progress_percent);
-
-    // Print current content if in visual mode
+    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
+        int progress_percent = (step * 100) / total_steps;
+        int progress_bars = (step * 50) / total_steps;
+        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
+                step,
+                total_steps,
+                std::string(progress_bars, '=').c_str(),
+                std::string(50 - progress_bars, ' ').c_str(),
+                progress_percent);
+    };
+
     if (data->diff_params->visual_mode) {
-        std::string current_text = "";
+        // Visual mode: clear
+        LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
+
+        print_progress_bar(step, total_steps);
+
+        LOG_INF("\n");
+
+        std::string current_text = " ";
+
         for (int32_t i = data->n_input; i < n_tokens; i++) {
             std::string token_str;
             if (tokens[i] != llama_vocab_mask(data->vocab)) {
@@ -158,14 +161,16 @@ static bool diffusion_step_callback(int32_t step,
                     token_str = piece;
                 }
             } else {
-                token_str = "_"; // Use underscore for mask tokens to show progress
+                token_str = " ";
             }
+
             current_text += token_str;
         }
-        printf("%s\n", current_text.c_str());
+
+        LOG_INF("%s\n", current_text.c_str());
+    } else {
+        print_progress_bar(step, total_steps);
     }
-
-    fflush(stdout);
 
     return true;
 }
@@ -493,8 +498,7 @@ static void diffusion_generate(llama_context * ctx,
     int64_t time_end = ggml_time_us();
     total_time += time_end - time_start;
 
-    // Print final timing info
-    LOG_INF("total time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
+    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
             total_time / 1000.0,
             total_time / 1000.0 / params.steps,
             total_sampling_time / 1000.0 / params.steps);
@@ -584,25 +588,51 @@ int main(int argc, char ** argv) {
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
+    std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
+
+    std::vector<llama_token> input_tokens = common_tokenize(vocab,
+                                                            formatted_prompt,
+                                                            /*add special tokens*/ true,
+                                                            /*parse special*/ true);
+
+    int n_input = input_tokens.size();
+
+    if (n_input >= params.n_ctx) {
+        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
+        llama_free(ctx);
+        llama_model_free(model);
+        return 1;
+    }
+
     llama_token mask_token_id = llama_vocab_mask(vocab);
+
     GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
+
     bool visual_mode = params.diffusion.visual_mode;
+
+    int32_t n_generated = 0;
+    std::vector<llama_token> output_tokens(params.n_ubatch);
+
     struct diffusion_params diff_params;
+
     char shift_logits_str[8];
     if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
         diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
     } else {
         diff_params.shift_logits = true;
     }
+
     //Use either eps or block length, but not both
     GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
+
     if (params.diffusion.eps) {
         diff_params.schedule = TIMESTEP_BASED;
         diff_params.eps = params.diffusion.eps;
     } else if (params.diffusion.block_length) {
         diff_params.schedule = BLOCK_BASED;
         diff_params.block_length = params.diffusion.block_length;
     }
+
     diff_params.mask_token_id = mask_token_id;
     diff_params.seed = params.sampling.seed;
     diff_params.temperature = params.sampling.temp;
@@ -613,64 +643,51 @@ int main(int argc, char ** argv) {
     diff_params.top_k = params.sampling.top_k;
     diff_params.visual_mode = params.diffusion.visual_mode;
     diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;
-    diff_params.step_callback = diffusion_step_callback;
-    // Interactive mode
-    if (params.prompt.empty()) {
-        LOG_INF("Entering interactive mode. Type exit or quit to leave.\n");
-        std::string user_prompt;
-        while (true) {
-            printf("\nEnter a question: ");
-            std::getline(std::cin, user_prompt);
-            if (user_prompt.empty()) continue;
-            if (user_prompt == "exit" || user_prompt == "quit") break;
-            std::string formatted_prompt = format_input_text(user_prompt, params.system_prompt, params.enable_chat_template, model);
-            std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt, true, true);
-            int n_input = input_tokens.size();
-            if (n_input >= params.n_ctx) {
-                LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
-                continue;
-            }
-            std::vector<llama_token> output_tokens(params.n_ubatch);
-            callback_data cb_data = { &diff_params, vocab, n_input };
-            diff_params.step_callback_user_data = &cb_data;
-            int n_generated = 0;
-            diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
-            output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
-            std::string output_data = common_detokenize(vocab, output_tokens, false);
-            //LOG_INF("\n%s\n", output_data.c_str());
-        }
-        llama_free(ctx);
-        llama_model_free(model);
-        llama_backend_free();
-        return 0;
+
+    diff_params.step_callback = diffusion_step_callback;
+    callback_data cb_data = { &diff_params, vocab, n_input };
+    diff_params.step_callback_user_data = &cb_data;
+
+    const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
+    const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
+    const char * alg_name =
+        (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
+    const char * sched_name =
+        (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";
+
+    LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
+    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
+    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
+    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
+    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
+    LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
+    if (diff_params.schedule == TIMESTEP_BASED) {
+        LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
+        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
     }
-    // Non-interactive mode: original logic
-    std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
-    std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt, true, true);
-    int n_input = input_tokens.size();
-    if (n_input >= params.n_ctx) {
-        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
-        llama_free(ctx);
-        llama_model_free(model);
-        return 1;
+    if (diff_params.schedule == BLOCK_BASED) {
+        LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
+        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
     }
-    std::vector<llama_token> output_tokens(params.n_ubatch);
-    callback_data cb_data = { &diff_params, vocab, n_input };
-    diff_params.step_callback_user_data = &cb_data;
-    int n_generated = 0;
+
     diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
+
     if (n_generated > 0) {
         if (visual_mode) {
+            //clear screen and move cursor to top-left
            LOG_INF("\033[2J\033[H");
         }
+
         output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
         std::string output_data = common_detokenize(vocab, output_tokens, false);
         LOG_INF("\n%s\n", output_data.c_str());
     } else {
         LOG_INF("Error: diffusion generation failed\n");
     }
+
     llama_free(ctx);
     llama_model_free(model);
     llama_backend_free();
+
     return 0;
-}
+}
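
For reference, the restored callback draws its progress bar with integer arithmetic over the step count and an in-place carriage-return redraw, clearing the screen with ANSI escapes only in visual mode. The sketch below is illustrative only and is not part of the commit; it assumes plain printf in place of llama.cpp's LOG_INF macro and reproduces the same rendering in isolation.

// Standalone sketch (not part of the commit) of the progress-bar rendering that
// diffusion_step_callback uses after this change. Assumption: printf stands in
// for LOG_INF; the integer math and escape codes mirror the diff above.
#include <cstdint>
#include <cstdio>
#include <string>

static void print_progress_bar(int32_t step, int32_t total_steps) {
    int progress_percent = (step * 100) / total_steps;  // 0..100, integer division
    int progress_bars    = (step * 50) / total_steps;   // bar is 50 characters wide
    printf("\rdiffusion step: %d/%d [%s%s] %d%%",       // '\r' redraws the same line
           step,
           total_steps,
           std::string(progress_bars, '=').c_str(),
           std::string(50 - progress_bars, ' ').c_str(),
           progress_percent);
    fflush(stdout);                                     // no '\n', so flush explicitly
}

int main() {
    const int32_t total_steps = 64;
    printf("\033[2J\033[H");                            // visual mode: clear screen, home cursor
    for (int32_t step = 1; step <= total_steps; ++step) {
        print_progress_bar(step, total_steps);
    }
    printf("\n");
    return 0;
}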
