11#include " arg.h"
22#include " chat.h"
33#include " common.h"
4- #include " diffusion.h"
54#include " llama.h"
65#include " log.h"
76
87#include < limits.h>
98#include < string>
109#include < vector>
10+ #include < algorithm>
11+ #include < cmath>
12+ #include < limits>
13+ #include < random>
14+
// Callback invoked once per diffusion step; return false to abort generation early.
typedef bool (*diffusion_step_callback_t)(int32_t step,
                                          int32_t total_steps,
                                          const llama_token * tokens,
                                          int32_t n_tokens,
                                          void * user_data);

// Strategy used to decide which masked positions get unmasked at each step.
enum diffusion_alg {
    DIFFUSION_ALG_ORIGIN       = 0,  // independent random unmasking with probability 1 - s/t
    DIFFUSION_ALG_MASKGIT_PLUS = 1,  // confidence = probability of the sampled token
    DIFFUSION_ALG_TOPK_MARGIN  = 2,  // confidence = margin between top-1 and top-2 probabilities
    DIFFUSION_ALG_ENTROPY      = 3,  // confidence = negative entropy of the distribution
};

// Hyperparameters for diffusion_generate().
struct diffusion_params {
    int32_t steps;                            // number of denoising steps
    float eps;                                // smallest timestep (keeps the schedule away from 0)
    float temperature;                        // sampling temperature (<= 0 disables the temp sampler)
    float top_p;                              // nucleus sampling threshold (>= 1 disables)
    int32_t top_k;                            // top-k filter (0 disables)
    llama_token mask_token_id;                // token id marking not-yet-generated positions
    enum diffusion_alg algorithm;             // unmasking strategy (see diffusion_alg)
    float alg_temp;                           // confidence temperature; 0 = deterministic ordering
    diffusion_step_callback_t step_callback;  // optional per-step progress hook (may be null)
    void * step_callback_user_data;           // opaque pointer forwarded to step_callback
    int32_t seed;                             // RNG seed for both samplers and the std::mt19937
};
41+
42+
43+ static diffusion_params diffusion_default_params () {
44+ diffusion_params params = {};
45+ params.steps = 64 ;
46+ params.eps = 1e-3f ;
47+ params.temperature = 0 .2f ;
48+ params.top_p = 0 .95f ;
49+ params.top_k = 0 ;
50+ params.mask_token_id = LLAMA_TOKEN_NULL;
51+ params.algorithm = DIFFUSION_ALG_ORIGIN;
52+ params.alg_temp = 0 .0f ;
53+ params.step_callback = nullptr ;
54+ params.step_callback_user_data = nullptr ;
55+ params.seed = 0 ;
56+ return params;
57+ }
58+
// Diffusion-based generation: initialize `output_tokens` with the prompt
// followed by mask tokens up to `max_length`, then repeatedly re-decode the
// whole sequence and "unmask" positions according to params.algorithm.
//
//   ctx           - llama context; attention is switched to non-causal and is
//                   NOT restored on exit (NOTE(review): confirm callers expect this)
//   input_tokens  - prompt tokens, length n_input
//   output_tokens - caller-provided buffer of at least max_length tokens
//   n_input       - number of prompt tokens (must be > 0)
//   max_length    - total sequence length to generate (must be > n_input)
//   params        - diffusion hyperparameters (see diffusion_params)
//   n_generated   - out: max_length on completion, 0 on invalid arguments
static void diffusion_generate(llama_context * ctx,
                               const llama_token * input_tokens,
                               llama_token * output_tokens,
                               int32_t n_input,
                               int32_t max_length,
                               struct diffusion_params params,
                               int32_t & n_generated) {

    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad the remainder with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);

    // RNG drives the random-transfer decisions of DIFFUSION_ALG_ORIGIN
    std::mt19937 rng(params.seed);

    // Timestep schedule: steps + 1 points decreasing linearly from 1.0 to eps
    std::vector<float> timesteps(params.steps + 1);
    for (int32_t i = 0; i <= params.steps; i++) {
        timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
    }

    // Diffusion decodes the full sequence bidirectionally
    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    // Scratch buffer reused for per-position token candidates
    std::vector<llama_token_data> candidates(n_vocab);

    // Scratch buffer for position-selection when alg_temp > 0
    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(max_length);

    std::vector<int32_t> mask_positions;
    mask_positions.reserve(max_length);

    // Token sampler chain: top-k -> top-p -> temperature -> distribution sample
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    // Separate sampler used only to draw WHICH positions to unmask (alg_temp > 0)
    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(max_length, 0, 1);
    batch.n_tokens = max_length;

    int64_t total_sampling_time = 0;
    int64_t total_time = 0;

    int64_t time_start = ggml_time_us();
    for (int32_t step = 0; step < params.steps; step++) {
        // Progress callback; a false return aborts generation early
        if (params.step_callback) {
            if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
                break;
            }
        }

        // Re-decode the entire (partially unmasked) sequence, requesting
        // logits for every position
        for (int32_t i = 0; i < max_length; i++) {
            batch.token[i] = output_tokens[i];
            batch.pos[i] = i;
            batch.n_seq_id[i] = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i] = 1;
        }

        int ret = llama_decode(ctx, batch);
        if (ret != 0) {
            LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
            break;
        }

        float * raw_logits = llama_get_logits(ctx);
        if (!raw_logits) {
            LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
            break;
        }

        // Logits at batch index i predict the token at position i + 1, so
        // position `pos` reads the logits computed at index pos - 1; position 0
        // has no predecessor and falls back to index 0
        // NOTE(review): positions 0 and 1 therefore read the same logits - confirm intended
        auto get_logits_for_pos = [&](int32_t pos) -> const float * {
            return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
        };

        int64_t time_start_sampling = ggml_time_us();

        // Collect the positions that are still masked
        mask_positions.clear();
        for (int32_t i = 0; i < max_length; i++) {
            if (output_tokens[i] == params.mask_token_id) {
                mask_positions.push_back(i);
            }
        }

        // Fully unmasked - nothing left to do
        if (mask_positions.empty()) {
            break;
        }

        // Current (t) and next (s) timesteps; s < t
        float t = timesteps[step];
        float s = timesteps[step + 1];

        if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
            // Each masked position is independently unmasked with probability
            // 1 - s/t; the final step unmasks everything that remains
            float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;

            for (int32_t pos : mask_positions) {
                if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                    const float * pos_logits = get_logits_for_pos(pos);
                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].id = token_id;
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p = 0.0f;
                    }

                    llama_token_data_array cur_p = {
                        /* .data = */ candidates.data(),
                        /* .size = */ (size_t) n_vocab, // Reset size to full vocab
                        /* .selected = */ -1,
                        /* .sorted = */ false,
                    };

                    llama_sampler_apply(sampler, &cur_p);
                    output_tokens[pos] = cur_p.data[cur_p.selected].id;
                }
            }
        } else {
            // Confidence-based algorithms: sample a candidate token for every
            // masked position, then commit only the most confident ones
            std::vector<std::pair<float, int32_t>> confidences; // (confidence, index into mask_positions)
            std::vector<llama_token> sampled_tokens(mask_positions.size());

            for (size_t i = 0; i < mask_positions.size(); i++) {
                int32_t pos = mask_positions[i];
                const float * pos_logits = get_logits_for_pos(pos);

                for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                    candidates[token_id].logit = pos_logits[token_id];
                    candidates[token_id].p = 0.0f;
                    candidates[token_id].id = token_id;
                }

                llama_token_data_array cur_p = {
                    /* .data = */ candidates.data(),
                    /* .size = */ candidates.size(),
                    /* .selected = */ -1,
                    /* .sorted = */ false,
                };

                llama_sampler_apply(sampler, &cur_p);

                llama_token sampled_token = cur_p.data[cur_p.selected].id;

                float confidence = 0.0f;
                if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
                    // Negative entropy: peaked distributions score higher
                    const float epsilon = 1e-10f; // guards logf(0)
                    for (size_t j = 0; j < cur_p.size; j++) {
                        float prob = cur_p.data[j].p;
                        confidence += prob * logf(prob + epsilon);
                    }
                } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
                    // Margin between the two most likely tokens
                    // NOTE(review): assumes cur_p is sorted descending by
                    // probability and holds >= 2 entries after filtering - confirm
                    confidence = cur_p.data[0].p - cur_p.data[1].p;
                } else {
                    // MASKGIT_PLUS (default): probability of the sampled token
                    confidence = cur_p.data[cur_p.selected].p;
                }

                sampled_tokens[i] = sampled_token;
                confidences.emplace_back(confidence, i);
            }

            // Number of positions to commit this step; everything on the last step
            int32_t num_transfer =
                (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();

            if (num_transfer > 0) {
                if (params.alg_temp == 0.0f) {
                    // Deterministic: keep the num_transfer highest-confidence
                    // entries first (ties broken by mask index)
                    std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
                                      [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                          if (a.first != b.first) {
                                              return a.first > b.first;
                                          }
                                          return a.second < b.second;
                                      });
                } else {
                    // Stochastic: draw positions with probability proportional to
                    // softmax(confidence / alg_temp)
                    conf_candidates.clear();

                    for (int32_t pos = 0; pos < max_length; pos++) {
                        // Already-unmasked positions get -inf so they can never be drawn
                        float conf_logit = -std::numeric_limits<float>::infinity();

                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            size_t mask_idx = std::distance(mask_positions.begin(), it);
                            conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
                        }

                        // `id` doubles as the absolute sequence position here
                        conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
                    }

                    llama_token_data_array conf_array = {
                        /* .data = */ conf_candidates.data(),
                        /* .size = */ conf_candidates.size(),
                        /* .selected = */ -1,
                        /* .sorted = */ false,
                    };

                    for (int32_t i = 0; i < num_transfer; i++) {
                        // Apply distribution sampler to get selected index
                        llama_sampler_apply(dist_sampler, &conf_array);
                        int selected_idx = conf_array.selected;
                        // Overwrite .second with the absolute position that was drawn
                        confidences[i].second = conf_candidates[selected_idx].id;

                        // Zero the drawn entry's probability so it is not picked again
                        // NOTE(review): relies on the dist sampler honoring pre-set p - confirm
                        conf_candidates[selected_idx].p = 0.0f;
                        conf_array.selected = -1;
                    }
                }

                if (params.alg_temp == 0.0f) {
                    // Deterministic - use confidence order; .second is a mask index
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t mask_idx = confidences[i].second;
                        int32_t pos = mask_positions[mask_idx];
                        llama_token token = sampled_tokens[mask_idx];
                        output_tokens[pos] = token;
                    }
                } else {
                    // Stochastic path stored absolute positions; map each back
                    // to its mask index to find the sampled token
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t pos = confidences[i].second;
                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            int32_t mask_idx = std::distance(mask_positions.begin(), it);
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    }
                }
            }
        }
        int64_t time_end_sampling = ggml_time_us();
        total_sampling_time += time_end_sampling - time_start_sampling;
    }
    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);


    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = max_length;
}
311+
312+
313+
11314
12315static std::string format_input_text (const std::string & prompt, bool use_chat_template, llama_model * model) {
13316 if (!use_chat_template) {
@@ -34,24 +337,24 @@ struct callback_data {
34337 int32_t n_input;
35338};
36339
37- static bool diffusion_step_callback (int32_t step
38- , int32_t total_steps
39- , const llama_token * tokens
40- , int32_t n_tokens
41- , void * user_data) {
340+ static bool diffusion_step_callback (int32_t step,
341+ int32_t total_steps,
342+ const llama_token * tokens,
343+ int32_t n_tokens,
344+ void * user_data) {
42345 (void )user_data;
43346
44347 callback_data * data = static_cast <callback_data *>(user_data);
45348
46349 auto print_progress_bar = [](int32_t step, int32_t total_steps) {
47350 int progress_percent = (step * 100 ) / total_steps;
48351 int progress_bars = (step * 50 ) / total_steps;
49- LOG_INF (" \r diffusion step: %d/%d [%s%s] %d%%"
50- , step
51- , total_steps
52- , std::string (progress_bars, ' =' ).c_str ()
53- , std::string (50 - progress_bars, ' ' ).c_str ()
54- , progress_percent);
352+ LOG_INF (" \r diffusion step: %d/%d [%s%s] %d%%" ,
353+ step,
354+ total_steps,
355+ std::string (progress_bars, ' =' ).c_str (),
356+ std::string (50 - progress_bars, ' ' ).c_str (),
357+ progress_percent);
55358 };
56359
57360 if (data->diff_params ->visual_mode ) {
@@ -157,7 +460,7 @@ int main(int argc, char ** argv) {
157460 ldiff_params.temperature = params.sampling .temp ;
158461 ldiff_params.top_p = params.sampling .top_p ;
159462 ldiff_params.top_k = params.sampling .top_k ;
160- ldiff_params.algorithm = static_cast <enum diffusion_algorithm >(params.diffusion .algorithm );
463+ ldiff_params.algorithm = static_cast <enum diffusion_alg >(params.diffusion .algorithm );
161464 ldiff_params.alg_temp = params.diffusion .alg_temp ;
162465 ldiff_params.seed = params.sampling .seed ;
163466
0 commit comments