Skip to content

Commit efef0a3

Browse files
committed
Make sampling fast
1 parent 05fe235 commit efef0a3

File tree

2 files changed

+16
-26
lines changed

2 files changed

+16
-26
lines changed

examples/diffusion/diffusion-cli.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ int main(int argc, char ** argv) {
177177
int32_t n_generated = 0;
178178
llama_token * generated = diffusion_generate(ctx, input_tokens.data(), n_input, params.diffusion.max_length,
179179
ldiff_params, &n_generated);
180-
181180
if (params.diffusion.visual_mode) {
182181
std::cerr << "\033[2J\033[H"; // Clear screen and move cursor to top-left
183182
} else {

examples/diffusion/diffusion.cpp

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
6464

6565
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
6666

67-
std::vector<llama_token_data> candidates;
68-
candidates.reserve(n_vocab);
67+
std::vector<llama_token_data> candidates(n_vocab);
68+
6969
std::vector<llama_token_data> conf_candidates;
7070
conf_candidates.reserve(max_length);
71+
7172
std::vector<int32_t> mask_positions;
7273
mask_positions.reserve(max_length);
7374

@@ -124,18 +125,9 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
124125
return nullptr;
125126
}
126127

127-
std::vector<float> shifted_logits(max_length * n_vocab);
128-
129-
// Move logits to the left, because the decode path will shift logits to the right
130-
// Position 0 keeps its own logits: shifted_logits[0] = raw_logits[0]
131-
std::copy(raw_logits, raw_logits + n_vocab, shifted_logits.data());
132-
133-
// Positions 1+ get logits from previous position: shifted_logits[i] = raw_logits[i-1]
134-
for (int32_t i = 1; i < max_length; i++) {
135-
std::copy(raw_logits + (i - 1) * n_vocab, raw_logits + i * n_vocab, shifted_logits.data() + i * n_vocab);
136-
}
137-
138-
float * logits = shifted_logits.data();
128+
auto get_logits_for_pos = [&](int32_t pos) -> const float* {
129+
return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
130+
};
139131

140132
mask_positions.clear();
141133
for (int32_t i = 0; i < max_length; i++) {
@@ -156,14 +148,16 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
156148

157149
for (int32_t pos : mask_positions) {
158150
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
159-
candidates.clear();
151+
const float* pos_logits = get_logits_for_pos(pos);
160152
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
161-
candidates.emplace_back(llama_token_data{ token_id, logits[pos * n_vocab + token_id], 0.0f });
153+
candidates[token_id].id = token_id;
154+
candidates[token_id].logit = pos_logits[token_id];
155+
candidates[token_id].p = 0.0f;
162156
}
163157

164158
llama_token_data_array cur_p = {
165159
/* .data = */ candidates.data(),
166-
/* .size = */ candidates.size(),
160+
/* .size = */ (size_t)n_vocab, // Reset size to full vocab
167161
/* .selected = */ -1,
168162
/* .sorted = */ false,
169163
};
@@ -173,19 +167,17 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
173167
}
174168
}
175169
} else {
176-
candidates.clear();
177-
candidates.shrink_to_fit();
178-
179170
std::vector<std::pair<float, int32_t>> confidences;
180171
std::vector<llama_token> sampled_tokens(mask_positions.size());
181172

182173
for (size_t i = 0; i < mask_positions.size(); i++) {
183174
int32_t pos = mask_positions[i];
184-
float * pos_logits = logits + pos * n_vocab;
175+
const float * pos_logits = get_logits_for_pos(pos);
185176

186-
candidates.clear();
187177
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
188-
candidates.emplace_back(llama_token_data{ token_id, pos_logits[token_id], 0.0f });
178+
candidates[token_id].logit = pos_logits[token_id];
179+
candidates[token_id].p = 0.0f;
180+
candidates[token_id].id = token_id;
189181
}
190182

191183
llama_token_data_array cur_p = {
@@ -207,15 +199,14 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
207199
confidence += prob * logf(prob + epsilon);
208200
}
209201
} else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
210-
std::partial_sort(cur_p.data, cur_p.data + 2, cur_p.data + cur_p.size,
211-
[](const llama_token_data & a, const llama_token_data & b) { return a.p > b.p; });
212202
confidence = cur_p.data[0].p - cur_p.data[1].p;
213203
} else {
214204
confidence = cur_p.data[cur_p.selected].p;
215205
}
216206

217207
sampled_tokens[i] = sampled_token;
218208
confidences.emplace_back(confidence, i);
209+
219210
}
220211

221212
int32_t num_transfer =

0 commit comments

Comments (0)