
Commit f0ffd81

adapt common
1 parent a1b1dea commit f0ffd81

File tree

common/common.cpp
common/speculative.cpp

2 files changed: +18 -16 lines changed


common/common.cpp

Lines changed: 4 additions & 2 deletions
@@ -1047,7 +1047,8 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
+            llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
                 decoder_start_token_id = bos;
@@ -1056,7 +1057,8 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_decode_ext(lctx, batch.get());
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
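
In common.cpp, the warm-up path swaps the transient llama_batch_get_one() view for an owned llama_batch_ext wrapped in llama_batch_ext_ptr, which frees the batch when it leaves scope. A minimal standalone sketch of the same flow, assuming the llama_batch_ext API from this branch (the trailing 0, 0 arguments of llama_batch_ext_init_from_text are inferred from the diff to be the starting position and sequence id; the helper name warmup_ext is hypothetical):

#include <algorithm>
#include <vector>

#include "llama.h"
#include "common.h"

static void warmup_ext(llama_context * lctx, const llama_model * model,
                       std::vector<llama_token> & tmp, int n_batch) {
    if (llama_model_has_encoder(model)) {
        // owned batch over the full token span, starting at position 0, sequence 0
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
        llama_encode_ext(lctx, batch.get());
    }
    if (llama_model_has_decoder(model)) {
        // cap the warm-up batch at n_batch tokens, as the original code does
        const size_t n = std::min(tmp.size(), (size_t) n_batch);
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), n, 0, 0));
        llama_decode_ext(lctx, batch.get());
    }
    // no llama_batch_free() needed: llama_batch_ext_ptr releases the batch on scope exit
}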

common/speculative.cpp

Lines changed: 14 additions & 14 deletions
@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_batch batch;
+    llama_batch_ext_ptr batch;
     llama_tokens prompt;
 };
 
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .batch  = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
         /* .prompt = */ {},
     };
 
@@ -68,8 +68,6 @@ void common_speculative_free(struct common_speculative * spec) {
 
     common_sampler_free(spec->smpl);
 
-    llama_batch_free(spec->batch);
-
     delete spec;
 }
 
@@ -150,6 +148,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
+    const llama_seq_id seq_id = 0;
+
     // reuse as much as possible from the old draft context
     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
@@ -205,40 +205,40 @@ llama_tokens common_speculative_gen_draft(
     }
 
     // prepare a batch to evaluate any new tokens in the prompt
-    common_batch_clear(batch);
+    llama_batch_ext_clear(batch.get());
 
     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+        llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
     // we should rarely end-up here during normal decoding
-    if (batch.n_tokens > 0) {
+    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
     }
 
     const llama_pos n_past = prompt.size();
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
-    common_batch_clear(batch);
-    common_batch_add (batch, id_last, n_past, { 0 }, true);
+    llama_batch_ext_clear(batch.get());
+    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);
 
     prompt.push_back(id_last);
 
     //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
-    llama_decode(ctx, batch);
+    llama_decode_ext(ctx, batch.get());
 
     common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         common_sampler_sample(smpl, ctx, 0, true);
 
@@ -265,10 +265,10 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
-        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+        llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
 
         prompt.push_back(id);
     }
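
In speculative.cpp, every decode now follows the same clear, add token, decode cycle against the batch held in the struct; because the member is a llama_batch_ext_ptr, the explicit llama_batch_free() in common_speculative_free() becomes unnecessary. A condensed sketch of one draft step, assuming the calls used in the diff above (&seq_id, 1 passes a single-element sequence-id array, and the final bool requests logits for the appended token; the helper draft_one is hypothetical):

#include "llama.h"
#include "common.h"
#include "sampling.h"

static llama_token draft_one(llama_context * ctx, common_sampler * smpl,
                             llama_batch_ext_ptr & batch,
                             llama_token id_last, llama_pos n_past) {
    const llama_seq_id seq_id = 0;

    // reset the reusable batch, then place id_last at position n_past in
    // sequence 0 with logits enabled
    llama_batch_ext_clear(batch.get());
    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);

    llama_decode_ext(ctx, batch.get());

    // sample from output slot 0 (the only token that produced logits)
    return common_sampler_sample(smpl, ctx, 0, true);
}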
