Merged
9 changes: 3 additions & 6 deletions common/arg.cpp
@@ -988,10 +988,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}

if (params.reranking && params.embedding) {
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}

if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
@@ -2747,9 +2743,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(common_arg(
{"--reranking", "--rerank"},
string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
[](common_params & params) {
params.reranking = true;
params.embedding = true;
params.pooling_type = LLAMA_POOLING_TYPE_RANK;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
add_opt(common_arg(
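Note: with this change `--reranking` / `--rerank` is just a shorthand for enabling embeddings with rank pooling, so the old mutual-exclusion check against `--embedding` is no longer needed. A minimal sketch of the equivalent explicit setup (field names taken from the diff above, assuming common.h from this repo is on the include path):

    #include "common.h"

    int main() {
        common_params params;
        // what "--rerank" now sets under the hood: no dedicated reranking flag,
        // just embeddings with RANK pooling
        params.embedding    = true;
        params.pooling_type = LLAMA_POOLING_TYPE_RANK;
        return 0;
    }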
62 changes: 29 additions & 33 deletions common/common.cpp
@@ -897,34 +897,6 @@ struct common_init_result common_init_from_params(common_params & params) {

const llama_vocab * vocab = llama_model_get_vocab(model);

if (params.reranking) {
bool ok = true;

if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
ok = false;
}

bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;

if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
} else if (!has_sep) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}

if (!ok) {
llama_model_free(model);

return iparams;
}
}

auto cparams = common_context_params_to_llama(params);

llama_context * lctx = llama_init_from_model(model, cparams);
@@ -966,6 +938,35 @@ struct common_init_result common_init_from_params(common_params & params) {
}
}

if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
bool ok = true;

if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
ok = false;
}

bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;

if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
} else if (!has_sep) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}

if (!ok) {
llama_free(lctx);
llama_model_free(model);

return iparams;
}
}

// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora;
@@ -1143,11 +1144,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;

if (params.reranking) {
cparams.embeddings = true;
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
}

cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;

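The BOS/EOS/SEP sanity check for reranking now runs after the context is created and keys off the context's pooling type instead of the removed `params.reranking` flag. A condensed, standalone sketch of the same check using only llama.h calls (branching simplified and warning text shortened, but the accept/reject outcome matches the diff above):

    #include "llama.h"
    #include <cstdio>

    static bool rerank_vocab_ok(const llama_model * model, const llama_context * ctx) {
        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_RANK) {
            return true; // not a reranking context, nothing to check
        }

        const llama_vocab * vocab = llama_model_get_vocab(model);

        bool ok = true;

        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
            fprintf(stderr, "vocab has no BOS token, reranking will not work\n");
            ok = false;
        }

        const bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
        const bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;

        if (!has_sep) {
            fprintf(stderr, "vocab has no SEP token, reranking will not work\n");
            ok = false;
        } else if (!has_eos) {
            fprintf(stderr, "vocab has no EOS token, using SEP token as fallback\n");
        }

        return ok;
    }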
1 change: 0 additions & 1 deletion common/common.h
@@ -355,7 +355,6 @@ struct common_params {
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
bool reranking = false; // enable reranking support on server

// server params
int32_t port = 8080; // server listens on this network port
8 changes: 5 additions & 3 deletions examples/gritlm/gritlm.cpp
@@ -41,12 +41,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
common_batch_add(batch, inputs[j], j, { 0 }, true);
}

// clear previous kv_cache values (irrelevant for embeddings)
llama_memory_clear(llama_get_memory(ctx), true);
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);

// run model
@@ -103,7 +102,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_token eos_token = llama_vocab_eos(vocab);

llama_memory_clear(llama_get_memory(ctx), true);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -166,6 +164,8 @@ int main(int argc, char * argv[]) {
llama_model_params mparams = common_model_params_to_llama(params);
llama_context_params cparams = common_context_params_to_llama(params);

cparams.embeddings = true;

llama_backend_init();

llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
@@ -213,6 +213,8 @@ int main(int argc, char * argv[]) {
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}

llama_set_embeddings(ctx, false);

// ### Generation ###
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
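The GritLM example no longer toggles embeddings on per encode call: the context is created with `cparams.embeddings = true`, every token is marked as an output, and the mode is switched off once before generation. A trimmed sketch of the resulting flow ("model.gguf" is a placeholder path, error handling omitted):

    #include "common.h"
    #include "llama.h"

    int main() {
        common_params params;
        params.model.path = "model.gguf"; // placeholder

        llama_backend_init();

        llama_model_params   mparams = common_model_params_to_llama(params);
        llama_context_params cparams = common_context_params_to_llama(params);
        cparams.embeddings = true; // embeddings enabled from the start

        llama_model   * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
        llama_context * ctx   = llama_init_from_model(model, cparams);

        // ... embedding passes: causal attention off, every token marked as an output ...

        llama_set_embeddings(ctx, false); // flip to logits for generation
        llama_set_causal_attn(ctx, true);

        // ... generation ...

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }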
12 changes: 7 additions & 5 deletions include/llama.h
@@ -254,16 +254,19 @@ extern "C" {
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for last token will be returned)
// (if set to NULL:
// - if embeddings: all tokens are output
// - if not: only the last token is output
// )
//
typedef struct llama_batch {
int32_t n_tokens;

llama_token * token;
float * embd;
llama_pos * pos;
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
} llama_batch;

@@ -961,8 +964,7 @@ extern "C" {
// Get the number of threads used for prompt and batch processing (multiple token).
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);

// Set whether the model is in embeddings mode or not
// If true, embeddings will be returned but logits will not
// Set whether the context outputs embeddings or not
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);

// Set whether to use causal attention or not
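The clarified contract: when `llama_batch.logits` is NULL, an embeddings context outputs every token while a non-embeddings context outputs only the last one. Callers that want to be explicit can set the output flags themselves; a small sketch using the common helpers (`fill_batch` is a hypothetical wrapper, not part of this PR):

    #include "common.h"
    #include "llama.h"

    #include <vector>

    // hypothetical helper: mark outputs explicitly instead of relying on logits == NULL
    static void fill_batch(llama_batch & batch, const std::vector<llama_token> & tokens, bool embeddings) {
        common_batch_clear(batch);
        for (size_t i = 0; i < tokens.size(); ++i) {
            // in embeddings mode every token is an output, otherwise only the last one
            const bool output = embeddings || i + 1 == tokens.size();
            common_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, output);
        }
    }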
30 changes: 26 additions & 4 deletions src/llama-batch.cpp
@@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
bool llama_batch_allocr::init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory) {
const llama_memory_i * memory,
bool embd_all) {
clear();

batch = batch_inp;
@@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
}

if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
if (embd_all) {
// return the output for all tokens
output.resize(batch.n_tokens, true);
} else {
// return the output only for the last token
output.resize(batch.n_tokens, false);
output[output.size() - 1] = true;
}

batch.logits = output.data();
} else if (embd_all) {
bool warn = false;

for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.logits[i] == 0) {
warn = true;
}
}

if (warn) {
LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);

output.resize(batch.n_tokens, true);
batch.logits = output.data();
}
}

//
3 changes: 2 additions & 1 deletion src/llama-batch.h
@@ -88,7 +88,8 @@ class llama_batch_allocr {
bool init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory);
const llama_memory_i * memory,
bool embd_all);

const llama_batch & get_batch() const;

26 changes: 11 additions & 15 deletions src/llama-context.cpp
@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
}

// note: during encode, we always pass the full sequence starting from pos = 0
if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1;
}

if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
// when computing embeddings, all tokens are output
const bool embd_all = cparams.embeddings;

if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@@ -911,12 +914,9 @@

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

const uint32_t n_outputs_all = batch_allocr->get_n_outputs();

if (embd_pooled) {
if (embd_all) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +945,7 @@
llama_memory_state_ptr mstate;

while (true) {
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
if (!mstate) {
return -2;
}
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}

auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;

if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;

// TODO: use a per-batch flag for logits presence instead
bool has_logits = !cparams.embeddings;
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
bool has_logits = true;
bool has_embd = cparams.embeddings;

// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(

n_queued_tokens += n_tokens_all;

// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

embd_seq.clear();

uint32_t n_outputs_all = n_tokens_all;

auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
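Net effect on the decode path: output buffers now always reserve logits, embeddings are reserved whenever `cparams.embeddings` is set, and with embeddings enabled every token of the batch is treated as an output (`embd_all`). A small sketch of reading token-level embeddings back after `llama_decode()` (assumes a context created with `cparams.embeddings = true` and pooling type `LLAMA_POOLING_TYPE_NONE`):

    #include "llama.h"
    #include <cstdio>

    static void dump_first_components(llama_context * ctx, int32_t n_tokens) {
        // with embeddings enabled, all n_tokens positions of the batch are outputs
        for (int32_t i = 0; i < n_tokens; ++i) {
            const float * embd = llama_get_embeddings_ith(ctx, i);
            if (embd) {
                printf("token %3d: embd[0] = %8.4f\n", i, embd[0]);
            }
        }
    }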
8 changes: 3 additions & 5 deletions src/llama-kv-cache-recurrent.cpp
@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}

llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);

llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

std::vector<llama_ubatch> ubatches;

while (sbatch.n_tokens > 0) {
llama_ubatch ubatch;

if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = sbatch.split_seq(n_ubatch);
} else {
ubatch = sbatch.split_equal(n_ubatch);
2 changes: 1 addition & 1 deletion src/llama-kv-cache-recurrent.h
@@ -32,7 +32,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;

llama_memory_state_ptr init_full() override;

4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified-iswa.cpp
@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);

// first try simple split
do {
2 changes: 1 addition & 1 deletion src/llama-kv-cache-unified-iswa.h
@@ -34,7 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;

llama_memory_state_ptr init_full() override;

4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified.cpp
@@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) {
GGML_UNUSED(embd_pooled);
bool embd_all) {
GGML_UNUSED(embd_all);

do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
2 changes: 1 addition & 1 deletion src/llama-kv-cache-unified.h
@@ -59,7 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;

llama_memory_state_ptr init_full() override;
