
Commit e8ddfa3

llama : rework embeddings logic
ggml-ci
1 parent d7da8dc commit e8ddfa3

13 files changed: +90 / -59 lines

examples/gritlm/gritlm.cpp

Lines changed: 5 additions & 3 deletions
@@ -41,12 +41,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

         // add input to batch (this increments n_tokens)
         for (int32_t j = 0; j < n_toks; j++) {
-            common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
+            common_batch_add(batch, inputs[j], j, { 0 }, true);
         }

         // clear previous kv_cache values (irrelevant for embeddings)
         llama_memory_clear(llama_get_memory(ctx), true);
-        llama_set_embeddings(ctx, true);
         llama_set_causal_attn(ctx, false);

         // run model
@@ -103,7 +102,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     llama_token eos_token = llama_vocab_eos(vocab);

     llama_memory_clear(llama_get_memory(ctx), true);
-    llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);

     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -166,6 +164,8 @@ int main(int argc, char * argv[]) {
     llama_model_params mparams = common_model_params_to_llama(params);
     llama_context_params cparams = common_context_params_to_llama(params);

+    cparams.embeddings = true;
+
     llama_backend_init();

     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
@@ -213,6 +213,8 @@ int main(int argc, char * argv[]) {
         std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
     }

+    llama_set_embeddings(ctx, false);
+
     // ### Generation ###
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
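
With this change the gritlm example creates the context in embeddings mode and only flips llama_set_embeddings(ctx, false) once, right before generation. A minimal sketch of the resulting call order, assuming llama_init_from_model for context creation; the batch building, decoding, and error handling of the real example are omitted:

    #include "llama.h"

    // sketch only: model and cparams come from the surrounding example code
    static void embeddings_then_generation(llama_model * model, llama_context_params cparams) {
        cparams.embeddings = true;                       // enable embeddings at creation time
        llama_context * ctx = llama_init_from_model(model, cparams);

        // embedding pass: no llama_set_embeddings(ctx, true) call is needed anymore
        llama_memory_clear(llama_get_memory(ctx), true);
        llama_set_causal_attn(ctx, false);
        // ... build the batch, llama_decode(), read the embeddings ...

        // generation pass: turn embedding output off and restore causal attention
        llama_set_embeddings(ctx, false);
        llama_set_causal_attn(ctx, true);
        // ... sample tokens as in generate() ...

        llama_free(ctx);
    }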

include/llama.h

Lines changed: 8 additions & 5 deletions
@@ -254,16 +254,19 @@ extern "C" {
     //   - seq_id : the sequence to which the respective token belongs
     //              (if set to NULL, the sequence ID will be assumed to be 0)
     //   - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
-    //              (if set to NULL, only the logits for last token will be returned)
+    //              (if set to NULL:
+    //               - if embeddings: all tokens are output
+    //               - if not:        only the last token is output
+    //              )
     //
     typedef struct llama_batch {
         int32_t n_tokens;

         llama_token  * token;
         float        * embd;
         llama_pos    * pos;
-        int32_t      * n_seq_id; // TODO: remove, should belong to only 1 sequence
-        llama_seq_id ** seq_id;  // TODO: become llama_seq_id * seq_id;
+        int32_t      * n_seq_id;
+        llama_seq_id ** seq_id;
         int8_t       * logits;   // TODO: rename this to "output"
     } llama_batch;

@@ -961,8 +964,8 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);

-    // Set whether the model is in embeddings mode or not
-    // If true, embeddings will be returned but logits will not
+    // Set whether the context outputs embeddings or not
+    // Note: set to true only if the context was created with llama_context_params.embeddings = true
    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);

    // Set whether to use causal attention or not
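
The batch.logits documentation change is the user-visible part of the rework: leaving logits NULL now means "output everything" when the context computes embeddings, and "output only the last token" otherwise. A hedged illustration; the two-argument llama_batch_get_one form and the helper name below are assumptions, not part of this commit:

    #include <vector>
    #include "llama.h"

    // illustrative: ctx was created from llama_context_params with embeddings enabled
    static int decode_with_default_outputs(llama_context * ctx, std::vector<llama_token> & tokens) {
        // batch.logits stays NULL with this helper, so the new default applies:
        //  - embeddings context: every token becomes an output row
        //  - otherwise:          only the last token is output
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

        return llama_decode(ctx, batch);
    }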

src/llama-batch.cpp

Lines changed: 26 additions & 4 deletions
@@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
 bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
-        const llama_memory_i * memory) {
+        const llama_memory_i * memory,
+        bool embd_all) {
     clear();

     batch = batch_inp;
@@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
     }

     if (!batch.logits) {
-        // by default return the output only for the last token
-        output.resize(batch.n_tokens);
-        output[output.size() - 1] = true;
+        if (embd_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
+        }
+
         batch.logits = output.data();
+    } else if (embd_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
+            }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
     }

     //
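
llama_batch_allocr::init now also repairs hand-built batches: if embeddings are required but some tokens were not marked as outputs, it logs the warning above and marks all of them. Callers that fill their own batches can avoid the warning by marking every token as an output up front; a small sketch using the common_batch_add helper from the examples' common code (build_batch, tokens, and embeddings_mode are illustrative names, not part of the commit):

    #include <vector>
    #include "common.h"

    static llama_batch build_batch(llama_context * ctx, const std::vector<llama_token> & tokens, bool embeddings_mode) {
        llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

        for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
            // in embeddings mode every token must be an output; otherwise only the last one
            const bool output = embeddings_mode || i == (int32_t) tokens.size() - 1;
            common_batch_add(batch, tokens[i], i, { 0 }, output);
        }

        return batch;
    }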

src/llama-batch.h

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ class llama_batch_allocr {
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
-            const llama_memory_i * memory);
+            const llama_memory_i * memory,
+            bool embd_all);

     const llama_batch & get_batch() const;

src/llama-context.cpp

Lines changed: 17 additions & 15 deletions
@@ -81,6 +81,12 @@ llama_context::llama_context(
         }
     }

+    if (!cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        LLAMA_LOG_WARN("%s: pooling_type is set to %d but embeddings is set to false - disabling pooling\n", __func__, cparams.pooling_type);
+
+        cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+    }
+
     if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
         cparams.causal_attn = hparams.causal_attn;
     } else {
@@ -728,7 +734,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }

     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +900,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }

-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +920,9 @@ int llama_context::decode(const llama_batch & batch_inp) {

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();

-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +951,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;

     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1064,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}

-    auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+    auto * t_logits = res->get_logits();
     auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;

     if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1228,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;

-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;

     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2049,11 @@ void llama_context::opt_epoch_iter(

         n_queued_tokens += n_tokens_all;

-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();

         uint32_t n_outputs_all = n_tokens_all;

-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;
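
Because output_reserve now always reserves the logits buffer (has_logits = true) and reserves embeddings whenever the context was created with embeddings enabled, a single decode can expose both views. A hedged sketch of what that permits at the API level, using the existing llama_get_logits_ith / llama_get_embeddings_ith accessors; the helper name is illustrative and the indices assume every token was marked as an output:

    #include "llama.h"

    static void read_both_views(llama_context * ctx, const llama_batch & batch) {
        if (llama_decode(ctx, batch) != 0) {
            return;
        }

        const int32_t last = batch.n_tokens - 1;

        const float * logits = llama_get_logits_ith(ctx, last);     // next-token logits
        const float * embd   = llama_get_embeddings_ith(ctx, last); // per-token embedding (pooling NONE)

        // with a pooling type set, read the pooled vector per sequence instead:
        // const float * pooled = llama_get_embeddings_seq(ctx, 0);
        (void) logits;
        (void) embd;
    }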

src/llama-kv-cache-recurrent.cpp

Lines changed: 3 additions & 5 deletions
@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }

-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
-
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

     std::vector<llama_ubatch> ubatches;

     while (sbatch.n_tokens > 0) {
         llama_ubatch ubatch;

-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
+        if (embd_all) {
+            // if all tokens are output, split by sequence
             ubatch = sbatch.split_seq(n_ubatch);
         } else {
             ubatch = sbatch.split_equal(n_ubatch);

src/llama-kv-cache-recurrent.h

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;

     llama_memory_state_ptr init_full() override;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 2 additions & 2 deletions
@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }

-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);

     // first try simple split
     do {

src/llama-kv-cache-unified-iswa.h

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;

     llama_memory_state_ptr init_full() override;
4040

src/llama-kv-cache-unified.cpp

Lines changed: 2 additions & 2 deletions
@@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         const llama_batch & batch,
         uint32_t n_ubatch,
-        bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+        bool embd_all) {
+    GGML_UNUSED(embd_all);

     do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
