Commit 6f54ee6

retrieval : avoid common_batch
ggml-ci
1 parent 32c2c41 commit 6f54ee6

2 files changed: 24 additions & 39 deletions


examples/retrieval/retrieval.cpp

Lines changed: 22 additions & 37 deletions
@@ -74,55 +74,38 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
-    size_t n_tokens = tokens.size();
+static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    const size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        batch.add_text(tokens[i], i, seq_id, true);
+        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
-    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
         // encoder-only model
-        if (llama_encode_ext(ctx, batch.get()) < 0) {
+        if (llama_encode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to encode\n", __func__);
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode_ext(ctx, batch.get()) < 0) {
+        if (llama_decode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
 
-    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
-        if (!batch.tokens[i].logits) {
-            continue;
-        }
-
-        const float * embd = nullptr;
-        int embd_pos = 0;
-
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
-            embd = llama_get_embeddings_ith(ctx, i);
-            embd_pos = i;
-            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
-            embd_pos = batch.tokens[i].seq_id;
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
-        }
+    for (int s = 0; s < n_seq; s++) {
+        const float * embd = llama_get_embeddings_seq(ctx, s);
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
-        float * out = output + embd_pos * n_embd;
+        float * out = output + s * n_embd;
         common_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
@@ -230,7 +213,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_chunks = chunks.size();
-    struct common_batch batch = common_batch(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
 
     // allocate output
     const int n_embd = llama_model_n_embd(model);
@@ -247,10 +230,10 @@ int main(int argc, char ** argv) {
        const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
-            batch.clear();
+        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
+            batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
+            llama_batch_ext_clear(batch);
+
             p += s;
             s = 0;
         }
@@ -261,8 +244,7 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
+    batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -271,7 +253,7 @@ int main(int argc, char ** argv) {
         chunks[i].tokens.clear();
     }
 
-    struct common_batch query_batch = common_batch(n_batch, 1);
+    llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1);
 
     // start loop, receive query and return top k similar chunks based on cosine similarity
     std::string query;
@@ -285,7 +267,7 @@ int main(int argc, char ** argv) {
         std::vector<float> query_emb(n_embd, 0);
         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
 
-        query_batch.clear();
+        llama_batch_ext_clear(query_batch);
 
         // compute cosine similarities
         {
@@ -314,6 +296,9 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    llama_batch_ext_free(batch);
+    llama_batch_ext_free(query_batch);
+
     // clean up
     llama_backend_free();
 }
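
For context, the lifecycle the reworked example follows is: llama_batch_ext_init, llama_batch_ext_add_text per token, llama_encode_ext/llama_decode_ext, llama_batch_ext_clear between batches, and llama_batch_ext_free at shutdown. Below is a minimal sketch of embedding a single tokenized sequence with this lifecycle, using only calls visible in the diff above; the helper name embed_one and the pooling-enabled context are assumptions for illustration, not part of the commit.

// Sketch: embed one tokenized sequence via the llama_batch_ext API.
// Requires llama.h and common.h (for common_embd_normalize).
static bool embed_one(llama_context * ctx, const std::vector<int32_t> & tokens, float * out, int n_embd) {
    llama_seq_id seq_id = 0;

    // one batch sized to the sequence, single sequence slot
    llama_batch_ext * batch = llama_batch_ext_init(tokens.size(), 1);

    // args: batch, token, position, seq_ids, n_seq_ids, output flag
    for (size_t i = 0; i < tokens.size(); i++) {
        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
    }

    bool ok = llama_decode_ext(ctx, batch) >= 0;
    if (ok) {
        // valid only when the context was created with pooling enabled
        const float * embd = llama_get_embeddings_seq(ctx, seq_id);
        ok = embd != nullptr;
        if (ok) {
            common_embd_normalize(embd, out, n_embd, 2); // L2 norm, as in the example
        }
    }

    llama_batch_ext_free(batch);
    return ok;
}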

include/llama.h

Lines changed: 2 additions & 2 deletions
@@ -945,8 +945,8 @@ extern "C" {
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
             float * embd,
-            size_t n_tokens,
-            size_t n_embd,
+             size_t n_tokens,
+             size_t n_embd,
             int32_t pos0,
             int32_t seq_id);
 
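
As a usage note, here is a hedged sketch of calling this constructor. The row-major buffer layout (n_tokens consecutive rows of n_embd floats) is inferred from the parameter names, since the commit only re-aligns the declaration; n_tokens and n_embd are assumed to be defined at the call site.

// Hypothetical call site: build a batch directly from precomputed embeddings.
std::vector<float> embd(n_tokens * n_embd); // filled with precomputed embeddings

llama_batch_ext * batch = llama_batch_ext_init_from_embd(
        embd.data(),
        n_tokens,
        n_embd,
        /*pos0   =*/ 0,  // position of the first token
        /*seq_id =*/ 0); // single sequence

// ... pass to llama_encode_ext()/llama_decode_ext() as appropriate ...

// the batch has to be freed with llama_batch_ext_free()
llama_batch_ext_free(batch);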
