
Commit 0d03605

cont : fix rerank

ggml-ci

1 parent 079d330

File tree: 4 files changed, +41 / -44 lines

common/arg.cpp

Lines changed: 3 additions & 6 deletions
@@ -988,10 +988,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.tensor_buft_overrides.push_back({nullptr, nullptr});
     }
 
-    if (params.reranking && params.embedding) {
-        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
-    }
-
     if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
         throw std::runtime_error(string_format(
             "error: the supplied chat template is not supported: %s%s\n",
@@ -2747,9 +2743,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
     add_opt(common_arg(
         {"--reranking", "--rerank"},
-        string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        string_format("enable reranking endpoint on server (default: %s)", "disabled"),
         [](common_params & params) {
-            params.reranking = true;
+            params.embedding = true;
+            params.pooling_type = LLAMA_POOLING_TYPE_RANK;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
    add_opt(common_arg(
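
Note: taken together with the field removed from common/common.h below, --reranking / --rerank is now just shorthand for enabling embeddings with rank pooling; the same state should also be reachable with --embedding --pooling rank. A minimal sketch of the equivalent programmatic setup and of the check that replaces the old params.reranking flag, assuming the llama.cpp common.h / llama.h headers (the main() wrapper is illustrative only):

    // Sketch only: assumes the llama.cpp source tree ("common.h", "llama.h").
    #include "common.h"
    #include "llama.h"

    #include <cstdio>

    int main() {
        common_params params;

        // what --rerank now does internally (see the option handler above)
        params.embedding    = true;
        params.pooling_type = LLAMA_POOLING_TYPE_RANK;

        // the removed `params.reranking` flag is replaced by checks of this form
        const bool wants_rerank = params.embedding && params.pooling_type == LLAMA_POOLING_TYPE_RANK;
        printf("rerank requested: %s\n", wants_rerank ? "yes" : "no");
        return 0;
    }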

common/common.cpp

Lines changed: 29 additions & 33 deletions
@@ -897,34 +897,6 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    if (params.reranking) {
-        bool ok = true;
-
-        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
-            ok = false;
-        }
-
-        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
-        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
-
-        if (!has_eos && !has_sep) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
-            ok = false;
-        } else if (!has_eos) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
-        } else if (!has_sep) {
-            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
-            ok = false;
-        }
-
-        if (!ok) {
-            llama_model_free(model);
-
-            return iparams;
-        }
-    }
-
     auto cparams = common_context_params_to_llama(params);
 
     llama_context * lctx = llama_init_from_model(model, cparams);
@@ -966,6 +938,35 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
     }
 
+    if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
+        bool ok = true;
+
+        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
+        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+
+        if (!has_eos && !has_sep) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+            ok = false;
+        } else if (!has_eos) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
+        } else if (!has_sep) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (!ok) {
+            llama_free(lctx);
+            llama_model_free(model);
+
+            return iparams;
+        }
+    }
+
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
         llama_adapter_lora_ptr lora;
@@ -1143,11 +1144,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full = params.swa_full;
 
-    if (params.reranking) {
-        cparams.embeddings = true;
-        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
-    }
-
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
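
Note: the vocab sanity check now runs after the context has been created and keys off llama_pooling_type(lctx) rather than the removed CLI flag, so it also applies when rank pooling comes from the model's GGUF metadata or from --pooling rank; on failure the new code also frees the context with llama_free(lctx) before releasing the model. The BOS/EOS/SEP warnings exist because a rerank input is a single sequence that splices the query and one candidate document together with those special tokens. A rough sketch of that assembly, as an illustration of the idea rather than the server's actual helper (only llama.h is assumed):

    // Sketch: build a rerank input roughly as BOS + query + EOS + SEP + document + EOS.
    // Falls back to SEP when the vocab has no EOS, mirroring the warning above.
    #include "llama.h"

    #include <vector>

    static std::vector<llama_token> make_rerank_input(
            const llama_vocab * vocab,
            const std::vector<llama_token> & query,
            const std::vector<llama_token> & doc) {
        llama_token eos = llama_vocab_eos(vocab);
        if (eos == LLAMA_TOKEN_NULL) {
            eos = llama_vocab_sep(vocab);
        }

        std::vector<llama_token> out;
        out.push_back(llama_vocab_bos(vocab));
        out.insert(out.end(), query.begin(), query.end());
        out.push_back(eos);
        out.push_back(llama_vocab_sep(vocab));
        out.insert(out.end(), doc.begin(), doc.end());
        out.push_back(eos);
        return out;
    }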

common/common.h

Lines changed: 0 additions & 1 deletion
@@ -355,7 +355,6 @@ struct common_params {
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
-    bool reranking = false; // enable reranking support on server
 
     // server params
     int32_t port = 8080; // server listens on this network port

tools/server/server.cpp

Lines changed: 9 additions & 4 deletions
@@ -3350,7 +3350,7 @@ struct server_context {
             common_set_adapter_lora(ctx, slot_batched->lora);
         }
 
-        const bool do_encode = (params_base.embedding || params_base.reranking);
+        const bool do_encode = params_base.embedding;
 
         // pad the batch so that batch.n_tokens >= n_slots
         // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
@@ -4567,13 +4567,18 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
-        const json body = json::parse(req.body);
+        if (!ctx_server.params_base.embedding) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
 
         if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
             res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
+        const json body = json::parse(req.body);
+
         // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
         if (body.count("input") != 0) {
@@ -4663,8 +4668,8 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
+        if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
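
Note: on the server side, the rerank endpoint is now gated on embeddings being enabled with LLAMA_POOLING_TYPE_RANK, and the old rule that --reranking and --embedding could not be combined (removed from common/arg.cpp above) disappears together with its error message. For reference, a hedged client-side sketch against a server started with something like "llama-server -m model.gguf --rerank"; it reuses the cpp-httplib and nlohmann::json libraries that tools/server already depends on, and the query/documents field names are an assumption based on the Jina-style rerank schema this endpoint mimics:

    // Sketch of a rerank request; the endpoint path and JSON field names are assumptions.
    #include <httplib.h>
    #include <nlohmann/json.hpp>

    #include <cstdio>

    int main() {
        httplib::Client cli("http://localhost:8080");

        const nlohmann::json body = {
            {"query",     "what is a panda?"},
            {"documents", {"The giant panda is a bear native to China.",
                           "Paris is the capital of France."}},
        };

        auto res = cli.Post("/v1/rerank", body.dump(), "application/json");
        if (res && res->status == 200) {
            printf("%s\n", res->body.c_str()); // expected: one relevance score per document
        } else {
            fprintf(stderr, "rerank request failed\n");
        }
        return 0;
    }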
