Skip to content

Commit 42516c8

Browse files
committed
Review: move everything to diffusion-cli for now
1 parent 4a13243 commit 42516c8

File tree

4 files changed

+317
-339
lines changed

4 files changed

+317
-339
lines changed

examples/diffusion/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Build the diffusion CLI example; all diffusion logic now lives in diffusion-cli.cpp
set(TARGET llama-diffusion-cli)
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
# Links against the core llama library, the shared example helpers, and threads
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/diffusion/diffusion-cli.cpp

Lines changed: 316 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,316 @@
11
#include "arg.h"
22
#include "chat.h"
33
#include "common.h"
4-
#include "diffusion.h"
54
#include "llama.h"
65
#include "log.h"
76

87
#include <limits.h>
98
#include <string>
109
#include <vector>
10+
#include <algorithm>
11+
#include <cmath>
12+
#include <limits>
13+
#include <random>
14+
15+
// Per-step progress callback invoked with the current (partially unmasked)
// token sequence. Returning false aborts generation early.
typedef bool (*diffusion_step_callback_t)(int32_t step,
                                          int32_t total_steps,
                                          const llama_token * tokens,
                                          int32_t n_tokens,
                                          void * user_data);
20+
21+
// Strategy used to decide which masked positions get committed each step.
enum diffusion_alg {
    DIFFUSION_ALG_ORIGIN       = 0, // independent per-position transfer with probability (1 - s/t)
    DIFFUSION_ALG_MASKGIT_PLUS = 1, // commit positions with the highest sampled-token probability
    DIFFUSION_ALG_TOPK_MARGIN  = 2, // confidence = margin between the two most probable tokens
    DIFFUSION_ALG_ENTROPY      = 3, // confidence = negative entropy of the per-position distribution
};
27+
28+
// Configuration for diffusion_generate().
struct diffusion_params {
    int32_t steps;                            // number of denoising steps
    float eps;                                // final timestep value; timesteps go 1.0 -> eps
    float temperature;                        // sampling temperature (applied only if > 0)
    float top_p;                              // nucleus sampling threshold (applied only if < 1)
    int32_t top_k;                            // top-k filter (applied only if > 0)
    llama_token mask_token_id;                // token used to pad/mark still-masked positions
    enum diffusion_alg algorithm;             // unmasking strategy, see diffusion_alg
    float alg_temp;                           // >0 enables stochastic (temperature-scaled) position selection
    diffusion_step_callback_t step_callback;  // optional progress callback; may be nullptr
    void * step_callback_user_data;           // opaque pointer forwarded to step_callback
    int32_t seed;                             // RNG seed for both token and position sampling
};
41+
42+
43+
// Returns a diffusion_params value populated with the default configuration.
// Note: mask_token_id defaults to LLAMA_TOKEN_NULL and must be set by the
// caller before generation.
static diffusion_params diffusion_default_params() {
    diffusion_params defaults = {};

    defaults.steps                   = 64;
    defaults.eps                     = 1e-3f;
    defaults.temperature             = 0.2f;
    defaults.top_p                   = 0.95f;
    defaults.top_k                   = 0;
    defaults.mask_token_id           = LLAMA_TOKEN_NULL;
    defaults.algorithm               = DIFFUSION_ALG_ORIGIN;
    defaults.alg_temp                = 0.0f;
    defaults.step_callback           = nullptr;
    defaults.step_callback_user_data = nullptr;
    defaults.seed                    = 0;

    return defaults;
}
58+
59+
// Runs masked-diffusion text generation: output_tokens[0..n_input) is copied
// from input_tokens, the remainder is filled with mask_token_id, and over
// params.steps iterations masked positions are progressively replaced with
// sampled tokens according to params.algorithm.
//
// On success n_generated is set to max_length; it stays 0 on invalid arguments.
// Side effects: disables causal attention on ctx (not restored on exit).
static void diffusion_generate(llama_context * ctx,
                               const llama_token * input_tokens,
                               llama_token * output_tokens,
                               int32_t n_input,
                               int32_t max_length,
                               struct diffusion_params params,
                               int32_t & n_generated) {

    n_generated = 0;
    // Validate: need room for at least one generated token beyond the prompt
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);

    std::mt19937 rng(params.seed);

    // Linear timestep schedule from 1.0 down to eps (steps + 1 points,
    // consecutive pairs (t, s) drive the per-step transfer ratio 1 - s/t)
    std::vector<float> timesteps(params.steps + 1);
    for (int32_t i = 0; i <= params.steps; i++) {
        timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
    }

    // Diffusion decoding attends bidirectionally over the whole sequence
    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    // Scratch buffers reused across steps to avoid per-step allocation
    std::vector<llama_token_data> candidates(n_vocab);

    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(max_length);

    std::vector<int32_t> mask_positions;
    mask_positions.reserve(max_length);

    // Token sampler chain: top-k -> top-p -> temperature -> dist (final pick).
    // Each filter is added only when its parameter actually constrains sampling.
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    // Separate distribution sampler used to pick *positions* when alg_temp > 0
    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(max_length, 0, 1);
    batch.n_tokens = max_length;

    int64_t total_sampling_time = 0;
    int64_t total_time = 0;

    int64_t time_start = ggml_time_us();
    for (int32_t step = 0; step < params.steps; step++) {
        // Report progress; a false return aborts generation
        if (params.step_callback) {
            if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
                break;
            }
        }

        // Re-decode the full sequence every step, requesting logits everywhere
        for (int32_t i = 0; i < max_length; i++) {
            batch.token[i] = output_tokens[i];
            batch.pos[i] = i;
            batch.n_seq_id[i] = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i] = 1;
        }

        int ret = llama_decode(ctx, batch);
        if (ret != 0) {
            LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
            break;
        }

        float * raw_logits = llama_get_logits(ctx);
        if (!raw_logits) {
            LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
            break;
        }

        // NOTE(review): logits for position `pos` are read from row (pos - 1),
        // with position 0 falling back to row 0 — this off-by-one shift looks
        // deliberate (row i predicts token i+1) but should be confirmed against
        // the llama_get_logits layout for this batch configuration.
        auto get_logits_for_pos = [&](int32_t pos) -> const float * {
            return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
        };

        int64_t time_start_sampling = ggml_time_us();

        // Collect positions that are still masked
        mask_positions.clear();
        for (int32_t i = 0; i < max_length; i++) {
            if (output_tokens[i] == params.mask_token_id) {
                mask_positions.push_back(i);
            }
        }

        // Everything unmasked — generation is complete
        if (mask_positions.empty()) {
            break;
        }

        float t = timesteps[step];
        float s = timesteps[step + 1];

        if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
            // Each masked position independently unmasks with probability
            // 1 - s/t; the final step forces all remaining positions
            float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;

            for (int32_t pos : mask_positions) {
                if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                    const float * pos_logits = get_logits_for_pos(pos);
                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].id = token_id;
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p = 0.0f;
                    }

                    llama_token_data_array cur_p = {
                        /* .data = */ candidates.data(),
                        /* .size = */ (size_t) n_vocab, // Reset size to full vocab
                        /* .selected = */ -1,
                        /* .sorted = */ false,
                    };

                    llama_sampler_apply(sampler, &cur_p);
                    output_tokens[pos] = cur_p.data[cur_p.selected].id;
                }
            }
        } else {
            // Confidence-based algorithms: sample a candidate token for every
            // masked position, score each position, then commit only the
            // num_transfer most confident (or stochastically chosen) ones.
            std::vector<std::pair<float, int32_t>> confidences; // (score, index into mask_positions)
            std::vector<llama_token> sampled_tokens(mask_positions.size());

            for (size_t i = 0; i < mask_positions.size(); i++) {
                int32_t pos = mask_positions[i];
                const float * pos_logits = get_logits_for_pos(pos);

                for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                    candidates[token_id].logit = pos_logits[token_id];
                    candidates[token_id].p = 0.0f;
                    candidates[token_id].id = token_id;
                }

                llama_token_data_array cur_p = {
                    /* .data = */ candidates.data(),
                    /* .size = */ candidates.size(),
                    /* .selected = */ -1,
                    /* .sorted = */ false,
                };

                llama_sampler_apply(sampler, &cur_p);

                llama_token sampled_token = cur_p.data[cur_p.selected].id;

                float confidence = 0.0f;
                if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
                    // confidence = sum p*log(p) = negative entropy; peaked
                    // distributions score higher (closer to 0)
                    const float epsilon = 1e-10f;
                    for (size_t j = 0; j < cur_p.size; j++) {
                        float prob = cur_p.data[j].p;
                        confidence += prob * logf(prob + epsilon);
                    }
                } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
                    // NOTE(review): assumes cur_p.data is sorted descending by
                    // probability after llama_sampler_apply — confirm the dist
                    // sampler leaves the array sorted, otherwise this margin is
                    // computed on arbitrary entries.
                    confidence = cur_p.data[0].p - cur_p.data[1].p;
                } else {
                    // MASKGIT_PLUS: probability of the sampled token itself
                    confidence = cur_p.data[cur_p.selected].p;
                }

                sampled_tokens[i] = sampled_token;
                confidences.emplace_back(confidence, i);
            }

            // Number of positions to commit this step; final step commits all
            int32_t num_transfer =
                (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();

            if (num_transfer > 0) {
                if (params.alg_temp == 0.0f) {
                    // Deterministic: bring the num_transfer highest-confidence
                    // entries to the front (ties broken by lower index)
                    std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
                                      [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                          if (a.first != b.first) {
                                              return a.first > b.first;
                                          }
                                          return a.second < b.second;
                                      });
                } else {
                    // Stochastic: treat temperature-scaled confidences as logits
                    // over sequence positions and sample positions without
                    // replacement (selected entries get p zeroed out).
                    conf_candidates.clear();

                    for (int32_t pos = 0; pos < max_length; pos++) {
                        // Non-masked positions get -inf so they are never chosen
                        float conf_logit = -std::numeric_limits<float>::infinity();

                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            size_t mask_idx = std::distance(mask_positions.begin(), it);
                            conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
                        }

                        conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
                    }

                    llama_token_data_array conf_array = {
                        /* .data = */ conf_candidates.data(),
                        /* .size = */ conf_candidates.size(),
                        /* .selected = */ -1,
                        /* .sorted = */ false,
                    };

                    for (int32_t i = 0; i < num_transfer; i++) {
                        // Apply distribution sampler to get selected index
                        llama_sampler_apply(dist_sampler, &conf_array);
                        int selected_idx = conf_array.selected;
                        // NOTE(review): .second is repurposed here to hold a
                        // *sequence position* (it held a mask index above) —
                        // the two commit loops below rely on this distinction.
                        confidences[i].second = conf_candidates[selected_idx].id;

                        conf_candidates[selected_idx].p = 0.0f;
                        conf_array.selected = -1;
                    }
                }

                if (params.alg_temp == 0.0f) {
                    // Deterministic - use confidence order
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t mask_idx = confidences[i].second;
                        int32_t pos = mask_positions[mask_idx];
                        llama_token token = sampled_tokens[mask_idx];
                        output_tokens[pos] = token;
                    }
                } else {
                    // Stochastic: .second already holds sequence positions; map
                    // each back to its mask index to fetch the sampled token
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t pos = confidences[i].second;
                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            int32_t mask_idx = std::distance(mask_positions.begin(), it);
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    }
                }
            }
        }
        int64_t time_end_sampling = ggml_time_us();
        total_sampling_time += time_end_sampling - time_start_sampling;
    }
    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = max_length;
}
311+
312+
313+
11314

12315
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
13316
if (!use_chat_template) {
@@ -34,24 +337,24 @@ struct callback_data {
34337
int32_t n_input;
35338
};
36339

37-
static bool diffusion_step_callback(int32_t step
38-
, int32_t total_steps
39-
, const llama_token * tokens
40-
, int32_t n_tokens
41-
, void * user_data) {
340+
static bool diffusion_step_callback(int32_t step,
341+
int32_t total_steps,
342+
const llama_token * tokens,
343+
int32_t n_tokens,
344+
void * user_data) {
42345
(void)user_data;
43346

44347
callback_data * data = static_cast<callback_data *>(user_data);
45348

46349
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
47350
int progress_percent = (step * 100) / total_steps;
48351
int progress_bars = (step * 50) / total_steps;
49-
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%"
50-
, step
51-
, total_steps
52-
, std::string(progress_bars, '=').c_str()
53-
, std::string(50 - progress_bars, ' ').c_str()
54-
, progress_percent);
352+
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
353+
step,
354+
total_steps,
355+
std::string(progress_bars, '=').c_str(),
356+
std::string(50 - progress_bars, ' ').c_str(),
357+
progress_percent);
55358
};
56359

57360
if (data->diff_params->visual_mode) {
@@ -157,7 +460,7 @@ int main(int argc, char ** argv) {
157460
ldiff_params.temperature = params.sampling.temp;
158461
ldiff_params.top_p = params.sampling.top_p;
159462
ldiff_params.top_k = params.sampling.top_k;
160-
ldiff_params.algorithm = static_cast<enum diffusion_algorithm>(params.diffusion.algorithm);
463+
ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
161464
ldiff_params.alg_temp = params.diffusion.alg_temp;
162465
ldiff_params.seed = params.sampling.seed;
163466

0 commit comments

Comments (0)