Skip to content

Commit ae0a627

Browse files
committed
First attempt at adding extra decay parameter
1 parent 2baf077 commit ae0a627

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

common/speculative.cpp

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include "sampling.h"
66

77
#include <cstring>
8+
#include <cmath>
89
#include <algorithm>
910

1011
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
@@ -153,6 +154,13 @@ llama_tokens common_speculative_gen_draft(
153154

154155
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
155156

157+
// Extract parameters packed in p_min (format: 0.pmin_pdecay_nmin)
158+
const float p_min = floorf(params.p_min * 100) / 100; // First 2 decimal places
159+
const float p_decay = floorf(params.p_min * 10000) / 100 - p_min; // Next 2 decimal places
160+
const int n_min = roundf((params.p_min * 100000) - (p_min * 100000) - (p_decay * 1000)); // Last digit
161+
162+
printf("p_min=%f, p_decay=%f, n_min=%d\n", p_min, p_decay, n_min);
163+
156164
// reuse as much as possible from the old draft context
157165
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
158166
for (int i = 0; i < (int) prompt.size(); ++i) {
@@ -239,6 +247,8 @@ llama_tokens common_speculative_gen_draft(
239247

240248
common_sampler_reset(smpl);
241249

250+
float sequence_p = 1.0;
251+
242252
// sample n_draft tokens from the draft model
243253
for (int i = 0; i < params.n_draft; ++i) {
244254
common_batch_clear(batch);
@@ -263,8 +273,14 @@ llama_tokens common_speculative_gen_draft(
263273
break;
264274
}
265275

276+
sequence_p *= cur_p->data[0].p;
277+
278+
const float threshold_p = p_min * pow(std::max((int) result.size() - std::max(n_min, 1), 1), -p_decay);
279+
280+
printf("sequence_p=%f, threshold_p=%f\n", sequence_p, threshold_p);
281+
266282
// only collect very high-confidence draft tokens
267-
if (cur_p->data[0].p < params.p_min) {
283+
if (sequence_p < threshold_p) {
268284
break;
269285
}
270286

@@ -276,5 +292,7 @@ llama_tokens common_speculative_gen_draft(
276292
prompt.push_back(id);
277293
}
278294

295+
printf("result.size()=%d, sequence_p=%f\n", result.size(), sequence_p);
296+
279297
return result;
280298
}

0 commit comments

Comments
 (0)