Skip to content

Commit 4c5db30

Browse files
committed
More cleanup
1 parent efd7d56 commit 4c5db30

File tree

4 files changed

+60
-92
lines changed

4 files changed

+60
-92
lines changed

examples/diffusion/diffusion-cli.cpp

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
#include "arg.h"
1010
#include "chat.h"
1111
#include "common.h"
12+
#include "diffusion.h"
1213
#include "llama.h"
1314
#include "log.h"
14-
#include "diffusion.h"
1515

1616
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
1717
if (!use_chat_template) {
@@ -36,7 +36,6 @@ struct callback_data {
3636
const common_params_diffusion * diff_params;
3737
const llama_vocab * vocab;
3838
int32_t n_input;
39-
llama_token mask_token_id; // Store mask token separately since it's not in diffusion params
4039
};
4140

4241
static bool diffusion_step_callback(int32_t step, int32_t total_steps, const llama_token * tokens, int32_t n_tokens,
@@ -46,8 +45,8 @@ static bool diffusion_step_callback(int32_t step, int32_t total_steps, const lla
4645
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
4746
int progress_percent = (step * 100) / total_steps;
4847
int progress_bars = (step * 50) / total_steps;
49-
std::cerr << "diffusion step: " << step << "/" << total_steps << " [" << std::string(progress_bars, '=')
50-
<< std::string(50 - progress_bars, ' ') << "] " << progress_percent << "%\n";
48+
std::cerr << "\rdiffusion step: " << step << "/" << total_steps << " [" << std::string(progress_bars, '=')
49+
<< std::string(50 - progress_bars, ' ') << "] " << progress_percent << "%";
5150
};
5251

5352
if (data->diff_params->visual_mode) {
@@ -56,11 +55,13 @@ static bool diffusion_step_callback(int32_t step, int32_t total_steps, const lla
5655

5756
print_progress_bar(step, total_steps);
5857

58+
std::cerr << "\n";
59+
5960
std::string current_text = " ";
6061

6162
for (int32_t i = data->n_input; i < n_tokens; i++) {
6263
std::string token_str;
63-
if (tokens[i] != data->mask_token_id) {
64+
if (tokens[i] != llama_vocab_mask(data->vocab)) {
6465
char piece[256];
6566
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
6667
if (n_chars > 0) {
@@ -135,9 +136,8 @@ int main(int argc, char ** argv) {
135136
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
136137

137138
std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
138-
true, // add_special tokens
139-
true // parse_special
140-
);
139+
/*add special tokens*/ true,
140+
/*parse special*/ true);
141141
int n_input = input_tokens.size();
142142

143143
if (n_input >= params.n_ctx) {
@@ -148,14 +148,14 @@ int main(int argc, char ** argv) {
148148
}
149149

150150
struct diffusion_params ldiff_params = diffusion_default_params();
151-
ldiff_params.steps = params.diffusion.steps;
152-
ldiff_params.eps = params.diffusion.eps;
153-
ldiff_params.temperature = params.sampling.temp;
154-
ldiff_params.top_p = params.sampling.top_p;
155-
ldiff_params.top_k = params.sampling.top_k;
156-
ldiff_params.algorithm = static_cast<enum diffusion_algorithm>(params.diffusion.algorithm);
157-
ldiff_params.alg_temp = params.diffusion.alg_temp;
158-
ldiff_params.seed = params.sampling.seed;
151+
ldiff_params.steps = params.diffusion.steps;
152+
ldiff_params.eps = params.diffusion.eps;
153+
ldiff_params.temperature = params.sampling.temp;
154+
ldiff_params.top_p = params.sampling.top_p;
155+
ldiff_params.top_k = params.sampling.top_k;
156+
ldiff_params.algorithm = static_cast<enum diffusion_algorithm>(params.diffusion.algorithm);
157+
ldiff_params.alg_temp = params.diffusion.alg_temp;
158+
ldiff_params.seed = params.sampling.seed;
159159

160160
llama_token mask_token_id = llama_vocab_mask(vocab);
161161
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
@@ -169,38 +169,36 @@ int main(int argc, char ** argv) {
169169

170170
ldiff_params.mask_token_id = mask_token_id;
171171

172-
callback_data cb_data = { &params.diffusion, vocab, n_input, mask_token_id };
172+
callback_data cb_data = { &params.diffusion, vocab, n_input };
173173

174174
ldiff_params.step_callback = diffusion_step_callback;
175175
ldiff_params.step_callback_user_data = &cb_data;
176176

177-
int32_t n_generated = 0;
177+
int32_t n_generated = 0;
178178

179-
int64_t t1 = ggml_time_us();
180-
llama_token * generated = diffusion_generate(ctx, input_tokens.data(), n_input, params.diffusion.max_length,
181-
ldiff_params, &n_generated);
179+
int64_t t1 = ggml_time_us();
180+
std::vector<llama_token> output_tokens(params.diffusion.max_length);
181+
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.diffusion.max_length,
182+
ldiff_params, &n_generated);
182183
int64_t t2 = ggml_time_us();
183-
if (params.diffusion.visual_mode) {
184-
std::cerr << "\033[2J\033[H"; // Clear screen and move cursor to top-left
185-
} else {
186-
std::cerr << "\r" << std::string(80, ' ') << "\r" << std::flush;
187-
}
188-
189-
if (generated && n_generated > 0) {
190-
std::vector<llama_token> output_tokens(generated + n_input, generated + n_generated);
191184

185+
if (n_generated > 0) {
186+
if (params.diffusion.visual_mode) {
187+
//clear screen and move cursor to top-left
188+
std::cerr << "\033[2J\033[H";
189+
}
190+
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
192191
std::string output_data = common_detokenize(vocab, output_tokens, false);
193-
std::cout << output_data << std::endl;
194-
195-
delete[] generated;
192+
std::cout << "\n" << output_data << "\n";
196193
} else {
197-
std::cerr << "Error: diffusion generation failed" << std::endl;
194+
std::cerr << "Error: diffusion generation failed\n";
198195
llama_free(ctx);
199196
llama_model_free(model);
200197
return 1;
201198
}
202199

203-
std::cerr << "diffusion time: " << (t2 - t1)/1000.0 << "ms time per step: " << (t2 - t1)/1000.0/params.diffusion.steps << "ms" << std::endl;
200+
std::cerr << "diffusion time: " << (t2 - t1) / 1000.0
201+
<< "ms time per step: " << (t2 - t1) / 1000.0 / params.diffusion.steps << "ms" << std::endl;
204202

205203
llama_free(ctx);
206204
llama_model_free(model);

examples/diffusion/diffusion.cpp

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#include "diffusion.h"
2-
#include "llama.h"
3-
#include "log.h"
42

53
#include <algorithm>
64
#include <cmath>
75
#include <limits>
86
#include <random>
97
#include <vector>
108

9+
#include "llama.h"
10+
#include "log.h"
11+
1112
struct diffusion_params diffusion_default_params(void) {
1213
struct diffusion_params params = {};
1314
params.steps = 64;
@@ -24,30 +25,15 @@ struct diffusion_params diffusion_default_params(void) {
2425
return params;
2526
}
2627

27-
llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_tokens, int32_t n_input,
28-
int32_t max_length, struct diffusion_params params, int32_t * n_generated) {
29-
if (!ctx || !input_tokens || n_input <= 0 || max_length <= n_input) {
28+
void diffusion_generate(llama_context * ctx, const llama_token * input_tokens, llama_token * output_tokens,
29+
int32_t n_input, int32_t max_length, struct diffusion_params params, int32_t * n_generated) {
30+
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
3031
if (n_generated) {
3132
*n_generated = 0;
3233
}
33-
return nullptr;
3434
}
3535

3636
const llama_model * model = llama_get_model(ctx);
37-
if (!model) {
38-
if (n_generated) {
39-
*n_generated = 0;
40-
}
41-
return nullptr;
42-
}
43-
44-
llama_token * output_tokens = new llama_token[max_length];
45-
if (!output_tokens) {
46-
if (n_generated) {
47-
*n_generated = 0;
48-
}
49-
return nullptr;
50-
}
5137

5238
// Initialize with input and pad with mask tokens
5339
std::copy(input_tokens, input_tokens + n_input, output_tokens);
@@ -107,25 +93,20 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
10793
if (ret != 0) {
10894
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
10995
llama_batch_free(batch);
110-
delete[] output_tokens;
111-
if (n_generated) {
112-
*n_generated = 0;
113-
}
114-
return nullptr;
96+
return;
11597
}
11698

11799
float * raw_logits = llama_get_logits(ctx);
118100
if (!raw_logits) {
119101
LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
120102
llama_batch_free(batch);
121-
delete[] output_tokens;
122103
if (n_generated) {
123104
*n_generated = 0;
124105
}
125-
return nullptr;
106+
return;
126107
}
127108

128-
auto get_logits_for_pos = [&](int32_t pos) -> const float* {
109+
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
129110
return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
130111
};
131112

@@ -148,16 +129,16 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
148129

149130
for (int32_t pos : mask_positions) {
150131
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
151-
const float* pos_logits = get_logits_for_pos(pos);
132+
const float * pos_logits = get_logits_for_pos(pos);
152133
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
153-
candidates[token_id].id = token_id;
134+
candidates[token_id].id = token_id;
154135
candidates[token_id].logit = pos_logits[token_id];
155-
candidates[token_id].p = 0.0f;
136+
candidates[token_id].p = 0.0f;
156137
}
157138

158139
llama_token_data_array cur_p = {
159140
/* .data = */ candidates.data(),
160-
/* .size = */ (size_t)n_vocab, // Reset size to full vocab
141+
/* .size = */ (size_t) n_vocab, // Reset size to full vocab
161142
/* .selected = */ -1,
162143
/* .sorted = */ false,
163144
};
@@ -171,13 +152,13 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
171152
std::vector<llama_token> sampled_tokens(mask_positions.size());
172153

173154
for (size_t i = 0; i < mask_positions.size(); i++) {
174-
int32_t pos = mask_positions[i];
155+
int32_t pos = mask_positions[i];
175156
const float * pos_logits = get_logits_for_pos(pos);
176157

177158
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
178159
candidates[token_id].logit = pos_logits[token_id];
179-
candidates[token_id].p = 0.0f;
180-
candidates[token_id].id = token_id;
160+
candidates[token_id].p = 0.0f;
161+
candidates[token_id].id = token_id;
181162
}
182163

183164
llama_token_data_array cur_p = {
@@ -206,7 +187,6 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
206187

207188
sampled_tokens[i] = sampled_token;
208189
confidences.emplace_back(confidence, i);
209-
210190
}
211191

212192
int32_t num_transfer =
@@ -284,6 +264,4 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
284264
if (n_generated) {
285265
*n_generated = max_length;
286266
}
287-
288-
return output_tokens;
289267
}

examples/diffusion/diffusion.h

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
extern "C" {
77
#endif
88

9-
typedef bool (*diffusion_step_callback_t)(int32_t step, int32_t total_steps, const llama_token * tokens, int32_t n_tokens, void * user_data);
9+
typedef bool (*diffusion_step_callback_t)(int32_t step, int32_t total_steps, const llama_token * tokens,
10+
int32_t n_tokens, void * user_data);
1011

1112
enum diffusion_algorithm {
12-
DIFFUSION_ALG_ORIGIN = 0,
13+
DIFFUSION_ALG_ORIGIN = 0,
1314
DIFFUSION_ALG_MASKGIT_PLUS = 1,
14-
DIFFUSION_ALG_TOPK_MARGIN = 2,
15-
DIFFUSION_ALG_ENTROPY = 3,
15+
DIFFUSION_ALG_TOPK_MARGIN = 2,
16+
DIFFUSION_ALG_ENTROPY = 3,
1617
};
1718

1819
struct diffusion_params {
@@ -31,13 +32,8 @@ struct diffusion_params {
3132

3233
struct diffusion_params diffusion_default_params(void);
3334

34-
llama_token * diffusion_generate(
35-
llama_context * ctx,
36-
const llama_token * input_tokens,
37-
int32_t n_input,
38-
int32_t max_length,
39-
struct diffusion_params params,
40-
int32_t * n_generated);
35+
void diffusion_generate(llama_context * ctx, const llama_token * input_tokens, llama_token * output_tokens,
36+
int32_t n_input, int32_t max_length, struct diffusion_params params, int32_t * n_generated);
4137

4238
#ifdef __cplusplus
4339
}

src/llama-context.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
#include "llama-context.h"
22

3-
#include <algorithm>
4-
#include <cinttypes>
5-
#include <cmath>
6-
#include <cstring>
7-
#include <limits>
8-
#include <random>
9-
#include <stdexcept>
10-
#include <vector>
11-
123
#include "llama-impl.h"
134
#include "llama-batch.h"
145
#include "llama-io.h"
156
#include "llama-memory.h"
167
#include "llama-mmap.h"
178
#include "llama-model.h"
189

10+
#include <cinttypes>
11+
#include <cstring>
12+
#include <limits>
13+
#include <stdexcept>
14+
1915
//
2016
// llama_context
2117
//

0 commit comments

Comments (0)