
Commit 8f41918

common : final touches
ggml-ci
1 parent 4eb126f commit 8f41918

File tree

4 files changed: +45 -24 lines changed


common/common.h

Lines changed: 1 addition & 1 deletion

@@ -156,7 +156,7 @@ struct common_params_sampling {
 };
 
 struct common_params_speculative {
-    int32_t n_ctx        = 4096; // draft context size
+    int32_t n_ctx        = 0;    // draft context size
     int32_t n_max        = 16;   // maximum number of tokens to draft during speculative decoding
     int32_t n_min        = 5;    // minimum number of draft tokens to use for speculative decoding
     int32_t n_gpu_layers = -1;   // number of layers to store in VRAM for the draft model (-1 - use default)
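
Note (illustrative, not part of the commit): the only functional change here is the default draft context size, which drops from a fixed 4096 tokens to 0. By llama.cpp convention a context size of 0 means the value is resolved later rather than treated as a hard limit; the example program further below shows how the draft context and batch sizes are then derived. A caller that prefers the old fixed size can still set the field explicitly, roughly like this:

#include "common.h"

// Illustrative sketch only: pin the draft context back to the old fixed
// default of 4096 tokens instead of letting it be sized automatically.
static void use_fixed_draft_ctx(common_params & params) {
    params.speculative.n_ctx = 4096; // previous default, removed by this commit
}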

common/speculative.cpp

Lines changed: 25 additions & 12 deletions

@@ -142,6 +142,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_tgt.size() &&
@@ -166,6 +168,8 @@ llama_tokens common_speculative_gen_draft(
 
         prompt.clear();
     } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
         if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
             for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
                 result.push_back(prompt[i]);
@@ -174,42 +178,51 @@ llama_tokens common_speculative_gen_draft(
                     break;
                 }
             }
+
             return result;
         }
 
-        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
 
-        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
-        prompt.erase(prompt.begin() + reuse_n, prompt.end());
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
     }
 
+    // prepare a batch to evaluate any new tokens in the prompt
     common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt_tgt.size() - i_start;
-
-    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
-
+    // we should rarely end-up here during normal decoding
     if (batch.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
         llama_decode(ctx, batch);
     }
 
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
     common_batch_clear(batch);
     common_batch_add (batch, id_last, n_past, { 0 }, true);
 
     prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
     llama_decode(ctx, batch);
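
Note (illustrative, not part of the commit): the restructured block above now touches the draft KV cache only when needed. If the reusable slice does not start at position 0, the prefix is removed and the remaining cache entries are shifted left by reuse_i; if stale tokens remain past the reused slice, they are dropped. A minimal sketch of the same position bookkeeping, using a plain std::vector in place of the cached draft prompt (in the real code, llama_kv_cache_seq_rm and llama_kv_cache_seq_add perform the analogous removal and shift on the draft model's KV cache):

#include <cstdio>
#include <vector>

// Sketch of the prompt/KV-cache bookkeeping above. reuse_i/reuse_n describe
// the slice of the old draft prompt that still matches the target prompt.
int main() {
    std::vector<int> prompt = {10, 11, 12, 13, 14, 15}; // old draft prompt (token ids)

    const int reuse_i = 2; // matching slice starts here ...
    const int reuse_n = 3; // ... and is this many tokens long

    if (reuse_i > 0) {
        // analogous to llama_kv_cache_seq_rm(ctx, 0, 0, reuse_i) followed by
        // llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i): drop the prefix
        // and shift the surviving entries so they start at position 0 again
        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
    }

    if (reuse_n < (int) prompt.size()) {
        // analogous to llama_kv_cache_seq_rm(ctx, 0, reuse_n, -1):
        // discard everything past the reused slice
        prompt.erase(prompt.begin() + reuse_n, prompt.end());
    }

    // positions 0..reuse_n-1 now hold exactly the reused tokens, so new tokens
    // can be appended starting at position prompt.size()
    for (size_t i = 0; i < prompt.size(); ++i) {
        printf("pos %zu -> token %d\n", i, prompt[i]);
    }

    return 0;
}

After this trimming and the append loop, prompt.size() matches the number of positions held in the draft cache, which is presumably why n_past is now computed as prompt.size() instead of prompt_tgt.size() - i_start.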

common/speculative.h

Lines changed: 6 additions & 7 deletions

@@ -6,10 +6,10 @@
 struct common_speculative;
 
 struct common_speculative_params {
-    int n_draft = 16;
+    int n_draft = 16; // max drafted tokens
     int n_reuse = 256;
 
-    float p_min = 0.9f;
+    float p_min = 0.9f; // min probabiliy required to accept a token in the draft
 };
 
 struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
@@ -21,9 +21,8 @@ bool common_speculative_are_compatible(
         const struct llama_context * ctx_dft);
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-//
 llama_tokens common_speculative_gen_draft(
-    struct common_speculative * spec,
-    struct common_speculative_params params,
-    const llama_tokens & prompt,
-    llama_token id_last);
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);
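
Note (illustrative, not part of the commit): apart from formatting, this hunk only adds comments for the two parameter fields. A hedged usage sketch of the header, modeled on how examples/speculative-simple/speculative-simple.cpp below uses it (ctx_dft, prompt_tgt and id_last are assumed to be set up by the caller, and common_speculative_free is assumed to be the matching cleanup call):

#include "common.h"
#include "speculative.h"

// Sketch: draft tokens with the API declared above.
// ctx_dft    - llama_context of the draft model (assumed already created)
// prompt_tgt - tokens already processed by the target model (assumed)
// id_last    - last token sampled from the target model (assumed)
static llama_tokens draft_next_tokens(llama_context * ctx_dft,
                                      const llama_tokens & prompt_tgt,
                                      llama_token id_last) {
    struct common_speculative_params params_spec;
    params_spec.n_draft = 16;   // max drafted tokens
    params_spec.n_reuse = 256;
    params_spec.p_min   = 0.9f; // min probability required to keep drafting

    struct common_speculative * spec = common_speculative_init(ctx_dft);

    // returns up to n_draft tokens proposed by the draft model
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    common_speculative_free(spec); // assumed counterpart of common_speculative_init

    return draft;
}

In the example program the speculator is created once and reused across decoding steps, which is what lets the n_reuse mechanism in speculative.cpp pay off.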

examples/speculative-simple/speculative-simple.cpp

Lines changed: 13 additions & 4 deletions

@@ -46,8 +46,11 @@ int main(int argc, char ** argv) {
     ctx_tgt = llama_init_tgt.context;
 
     // load the draft model
-    params.model = params.speculative.model;
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
     params.n_gpu_layers = params.speculative.n_gpu_layers;
+
     if (params.speculative.cpuparams.n_threads > 0) {
         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
     }
@@ -66,8 +69,14 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
 
-    if ((int) inp.size() > llama_n_ctx(ctx_tgt)) {
-        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+    if (llama_n_ctx(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+        return 1;
+    }
+
+    if (llama_n_batch(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
 
         return 1;
     }
@@ -114,7 +123,7 @@ int main(int argc, char ** argv) {
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
-    params_spec.n_reuse = 256;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
     params_spec.p_min   = p_min;
 
    struct common_speculative * spec = common_speculative_init(ctx_dft);
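
Note (illustrative, not part of the commit): the reuse window is now tied to the draft context instead of a hard-coded 256 tokens. For example, if the draft model ends up with a 4096-token context and n_draft = 16, then n_reuse = 4096 - 16 = 4080, so nearly the whole draft context can be recycled between calls. A small sketch of the parameter plumbing above (all concrete values are assumptions):

#include <cstdint>
#include <cstdio>

// Sketch of how the draft batch size and the reuse window are derived.
int main() {
    const int32_t spec_n_ctx  = 0;    // params.speculative.n_ctx (new default: 0 = resolved later)
    const int32_t tgt_n_batch = 2048; // params.n_batch of the target setup (assumed value)
    const int32_t n_draft     = 16;   // max tokens to draft per step

    // mirrors: params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
    const int32_t dft_n_batch = spec_n_ctx > 0 ? spec_n_ctx : tgt_n_batch;

    // suppose the draft context ends up at 4096 tokens once the model is loaded (assumed value)
    const int32_t dft_n_ctx = spec_n_ctx > 0 ? spec_n_ctx : 4096;

    // mirrors: params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
    const int32_t n_reuse = dft_n_ctx - n_draft;

    printf("draft n_batch = %d, draft n_ctx = %d, n_reuse = %d\n", dft_n_batch, dft_n_ctx, n_reuse);

    return 0;
}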
