Skip to content

Commit 17855ff

Browse files
committed
llama : add normalized field to llama_token_data_array struct
This commit adds a 'normalized' field to the llama_token_data_array struct to indicate whether the probabilities have been computed and normalized from the logits. The motivation for this change is to avoid redundant normalization calls in the sampling code, as the softmax calculation can be expensive depending on the size of the llama_token_data array. Samplers that modify logits or filter tokens (change the size) must set normalized to false to invalidate cached probabilities. Samplers that compute probabilities set it to true after normalization.
1 parent df1b612 commit 17855ff

File tree

7 files changed

+69
-14
lines changed

7 files changed

+69
-14
lines changed

common/sampling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ struct common_sampler {
126126
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
127127
}
128128

129-
cur_p = { cur.data(), cur.size(), -1, false };
129+
cur_p = { cur.data(), cur.size(), false, -1, false };
130130
}
131131
};
132132

@@ -360,7 +360,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
360360
// check if it the sampled token fits the grammar
361361
{
362362
llama_token_data single_token_data = { id, 1.0f, 0.0f };
363-
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
363+
llama_token_data_array single_token_data_array = { &single_token_data, 1, false, -1, false };
364364

365365
llama_sampler_apply(grmr, &single_token_data_array);
366366

examples/diffusion/diffusion-cli.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ static void diffusion_generate(llama_context * ctx,
404404
llama_token_data_array cur_p = {
405405
candidates.data(),
406406
(size_t) n_vocab,
407+
false, // normalized
407408
-1,
408409
false,
409410
};
@@ -429,6 +430,7 @@ static void diffusion_generate(llama_context * ctx,
429430
llama_token_data_array cur_p = {
430431
candidates.data(),
431432
candidates.size(),
433+
false, // normalized
432434
-1,
433435
false,
434436
};
@@ -472,6 +474,7 @@ static void diffusion_generate(llama_context * ctx,
472474
llama_token_data_array conf_array = {
473475
conf_candidates.data(),
474476
conf_candidates.size(),
477+
false,
475478
-1,
476479
false,
477480
};

examples/speculative/speculative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ int main(int argc, char ** argv) {
269269

270270
LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
271271
float r = u_dist(rng);
272-
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
272+
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), false, LLAMA_TOKEN_NULL, true };
273273

274274
//GGML_ASSERT(dist_tgt.size <= dist_dft.size);
275275

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ extern "C" {
205205
// NOTE: this pointer can be modified by the samplers
206206
llama_token_data * data;
207207
size_t size;
208+
bool normalized; // true if the probabilities (llama_token_data.p) have been computed
208209
int64_t selected; // this is the index in the data array (i.e. not the token id)
209210
bool sorted; // note: do not assume the data is sorted - always check this flag
210211
} llama_token_data_array;

src/llama-grammar.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
11561156
for (const auto & reject : rejects) {
11571157
cur_p->data[reject.index].logit = -INFINITY;
11581158
}
1159+
cur_p->normalized = false;
11591160
}
11601161

11611162
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {

src/llama-sampling.cpp

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ static void llama_log_softmax(float * array, size_t size) {
260260
*/
261261

262262
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
263+
cur_p->normalized = false;
263264
if (temp <= 0.0f) {
264265
// find the token with the highest logit and set the rest to -inf
265266
size_t max_i = 0;
@@ -309,6 +310,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_s
309310
for (size_t i = 0; i < cur_p->size; ++i) {
310311
cur_p->data[i].p /= cum_sum;
311312
}
313+
cur_p->normalized = true;
312314
}
313315

314316
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
@@ -328,6 +330,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
328330
}
329331

330332
cur_p->size = k;
333+
cur_p->normalized = false;
331334
}
332335

333336
static uint32_t get_rng_seed(uint32_t seed) {
@@ -422,6 +425,7 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
422425
llama_token_data_array cur_p = {
423426
/* .data = */ cur.data(),
424427
/* .size = */ cur.size(),
428+
/* .normalized = */ false,
425429
/* .selected = */ -1,
426430
/* .sorted = */ false,
427431
};
@@ -614,6 +618,23 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
614618

615619
if (cur_p->size == 1) {
616620
cur_p->data[0].p = 1.0f;
621+
cur_p->normalized = true;
622+
return;
623+
}
624+
625+
if (cur_p->normalized) {
626+
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
627+
const double rnd = dist(ctx->rng);
628+
double sum_run = 0.0f;
629+
630+
for (size_t i = 0; i < cur_p->size; ++i) {
631+
sum_run += cur_p->data[i].p;
632+
if (sum_run >= rnd) {
633+
cur_p->selected = i;
634+
return;
635+
}
636+
}
637+
cur_p->selected = cur_p->size - 1;
617638
return;
618639
}
619640

@@ -663,13 +684,15 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
663684
if (!found) {
664685
cur_p->selected = cur_p->size - 1;
665686
}
687+
cur_p->normalized = true;
666688
#else
667689
// for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
668690
for (size_t i = 0; i < cur_p->size; ++i) {
669691
cur_p->data[i].p /= sum_cum;
670692
}
671693
672694
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
695+
cur_p->normalized = true;
673696
#endif
674697
}
675698

@@ -780,7 +803,9 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
780803
return;
781804
}
782805

783-
llama_sampler_softmax_impl(cur_p, false);
806+
if (!cur_p->normalized) {
807+
llama_sampler_softmax_impl(cur_p, false);
808+
}
784809

785810
size_t k = cur_p->size;
786811
auto * pdata = cur_p->data;
@@ -826,6 +851,7 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
826851
}
827852

828853
cur_p->size = last_idx;
854+
cur_p->normalized = false;
829855
}
830856

831857
static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
@@ -897,6 +923,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
897923
if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
898924
std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
899925
cur_p->size = filtered_tokens.size();
926+
cur_p->normalized = false;
900927
min_p_applied = true;
901928
}
902929
}
@@ -919,6 +946,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
919946

920947
// Resize the output vector to keep only the matching tokens
921948
cur_p->size = i;
949+
cur_p->normalized = false;
922950
}
923951
}
924952

@@ -971,7 +999,9 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
971999
}
9721000

9731001
// Compute the softmax of logits and calculate entropy
974-
llama_sampler_softmax_impl(cur_p, true);
1002+
if (!cur_p->normalized) {
1003+
llama_sampler_softmax_impl(cur_p, true);
1004+
}
9751005

9761006
float entropy = 0.0f;
9771007
for (size_t i = 0; i < cur_p->size; ++i) {
@@ -1019,6 +1049,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
10191049
std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
10201050
cur_p->size = cur_p_new.size();
10211051
cur_p->sorted = false;
1052+
cur_p->normalized = false;
10221053
}
10231054

10241055
static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
@@ -1120,7 +1151,9 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
11201151
// Calculate maximum possible entropy
11211152
float max_entropy = -logf(1.0f / cur_p->size);
11221153

1123-
llama_sampler_softmax_impl(cur_p, true);
1154+
if (!cur_p->normalized) {
1155+
llama_sampler_softmax_impl(cur_p, true);
1156+
}
11241157

11251158
// Calculate entropy of the softmax probabilities
11261159
float entropy = 0.0f;
@@ -1162,6 +1195,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
11621195
for (size_t i = 0; i < cur_p->size; ++i) {
11631196
cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
11641197
}
1198+
cur_p->normalized = true;
11651199

11661200
#ifdef DEBUG
11671201
// Print the updated top 25 probabilities after temperature scaling
@@ -1236,7 +1270,9 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
12361270
return;
12371271
}
12381272

1239-
llama_sampler_softmax_impl(cur_p, true);
1273+
if (!cur_p->normalized) {
1274+
llama_sampler_softmax_impl(cur_p, true);
1275+
}
12401276

12411277
int pos_last = 0;
12421278

@@ -1251,6 +1287,7 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
12511287
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
12521288
cur_p->data += pos_last;
12531289
cur_p->size -= pos_last;
1290+
cur_p->normalized = false;
12541291
}
12551292
}
12561293

@@ -1327,7 +1364,9 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*s
13271364
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
13281365
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
13291366

1330-
llama_sampler_softmax_impl(cur_p, true);
1367+
if (!cur_p->normalized) {
1368+
llama_sampler_softmax_impl(cur_p, true);
1369+
}
13311370

13321371
// Estimate s_hat using the most probable m tokens
13331372
float s_hat = 0.0;
@@ -1433,7 +1472,9 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler *
14331472
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
14341473
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
14351474

1436-
llama_sampler_softmax_impl(cur_p, true);
1475+
if (!cur_p->normalized) {
1476+
llama_sampler_softmax_impl(cur_p, true);
1477+
}
14371478

14381479
// Truncate the words with surprise values greater than mu
14391480
cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
@@ -1775,6 +1816,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
17751816
}
17761817

17771818
cur_p->sorted = false;
1819+
cur_p->normalized = false;
17781820
}
17791821

17801822
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
@@ -2193,6 +2235,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
21932235
}
21942236

21952237
cur_p->sorted = false;
2238+
cur_p->normalized = false;
21962239
}
21972240

21982241
static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
@@ -2344,6 +2387,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
23442387
}
23452388

23462389
if (ctx->to_search.empty()) {
2390+
cur_p->normalized = false;
23472391
return;
23482392
}
23492393

@@ -2356,6 +2400,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
23562400
}
23572401
}
23582402
}
2403+
cur_p->normalized = false;
23592404
}
23602405

23612406
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
@@ -2408,7 +2453,9 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smp
24082453
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
24092454
auto * ctx = (llama_sampler_infill *) smpl->ctx;
24102455

2411-
llama_sampler_softmax_impl(cur_p, true);
2456+
if (!cur_p->normalized) {
2457+
llama_sampler_softmax_impl(cur_p, true);
2458+
}
24122459

24132460
#if defined(GGML_DEBUG_SAMPLER_INFILL)
24142461
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
@@ -2457,6 +2504,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
24572504
for (size_t i = 0; i < cur_p->size; ++i) {
24582505
cur_p->data[i].p /= p_sum;
24592506
}
2507+
cur_p->normalized = true;
24602508

24612509
return;
24622510
}
@@ -2542,6 +2590,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
25422590
cur_p->size = 1;
25432591
cur_p->data[0].id = ctx->vocab->token_eot();
25442592
cur_p->data[0].logit = 1.0f;
2593+
cur_p->normalized = true;
25452594

25462595
return;
25472596
}
@@ -2579,6 +2628,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
25792628

25802629
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
25812630
}
2631+
cur_p->normalized = true;
25822632

25832633
#undef LOG_DBG_CUR
25842634
}

tests/test-sampling.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct sampler_tester {
2828
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
2929
}
3030

31-
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
31+
cur_p = llama_token_data_array { cur.data(), cur.size(), false, -1, false };
3232
}
3333

3434
sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
@@ -38,7 +38,7 @@ struct sampler_tester {
3838
cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
3939
}
4040

41-
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
41+
cur_p = llama_token_data_array { cur.data(), cur.size(), false, -1, false };
4242
}
4343

4444
void apply(llama_sampler * sampler) {
@@ -270,13 +270,13 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
270270
static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
271271
std::vector<llama_token_data> cur(data.size());
272272
std::copy(data.begin(), data.end(), cur.begin());
273-
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
273+
llama_token_data_array cur_p = { cur.data(), cur.size(), false, -1, false };
274274
llama_sampler_apply(cnstr, &cur_p);
275275
llama_sampler_reset(cnstr);
276276
const int64_t t_start = ggml_time_us();
277277
for (int i = 0; i < n_iter; i++) {
278278
std::copy(data.begin(), data.end(), cur.begin());
279-
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
279+
llama_token_data_array cur_p = { cur.data(), cur.size(), false, -1, false };
280280
llama_sampler_apply(cnstr, &cur_p);
281281
llama_sampler_reset(cnstr);
282282
}

0 commit comments

Comments (0)