
Commit c2f4dc7 (1 parent: c02f53d)

fix prompt

File tree: 5 files changed (+46, -47 lines)

common/common.cpp

Lines changed: 3 additions & 4 deletions
@@ -905,10 +905,9 @@ struct common_init_result common_init_from_params(common_params & params) {
             ok = false;
         }

-        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
-        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
-        bool has_rerank_prompt = llama_model_chat_template(model, "rerank_prefix") != NULL ||
-                                 llama_model_chat_template(model, "rerank_suffix") != NULL;
+        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
+        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+        bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;

         if (has_rerank_prompt) {
             // OK, do nothing

convert_hf_to_gguf.py

Lines changed: 4 additions & 5 deletions
@@ -3099,11 +3099,10 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
         self.gguf_writer.add_classifier_output_labels(["yes", "no"])
         self.gguf_writer.add_chat_template([{
-            "name": "rerank_prefix",
-            "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n",
-        }, {
-            "name": "rerank_suffix",
-            "template": "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
+            "name": "rerank",
+            "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+                + "<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}\n"
+                + "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
         }])

     def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:

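The two templates are merged into a single "rerank" template that carries literal {query} and {document} placeholders, which the server fills in per document (see tools/server/utils.hpp below). A minimal, self-contained sketch of that rendering step, using only the standard library; replace_all here is a hypothetical stand-in for the string_replace_all call used further down:

#include <string>

// hypothetical stand-in for string_replace_all() used in tools/server/utils.hpp
static void replace_all(std::string & s, const std::string & from, const std::string & to) {
    for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
        s.replace(pos, from.size(), to);
    }
}

// render the stored "rerank" template for one query/document pair
static std::string render_rerank_prompt(std::string tmpl, const std::string & query, const std::string & doc) {
    replace_all(tmpl, "{query}", query);
    replace_all(tmpl, "{document}", doc);
    return tmpl; // chat-formatted prompt, ready to tokenize
}
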
src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
@@ -1585,7 +1585,7 @@ void llm_graph_context::build_pooling(
             } else if (cls_out) {
                 if (arch == LLM_ARCH_QWEN3) {
                     cur = ggml_mul_mat(ctx0, cls_out, inp);
-                    cur = ggml_soft_max(ctx0, cur); // qwen3 uses softmax on the output
+                    cur = ggml_log(ctx0, ggml_soft_max(ctx0, cur)); // qwen3 uses log_softmax
                 } else {
                     // Single layer classification head (direct projection)
                     // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476

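The change here only wraps the existing softmax in a log, so the Qwen3 classifier head now emits log-probabilities over the "yes"/"no" labels, i.e. the relevance score becomes log p(yes) rather than p(yes). A minimal sketch in plain C++ (standalone, not ggml) included only to spell out the equivalence log_softmax(x)_i = x_i - log(sum_j exp(x_j)):

#include <algorithm>
#include <cmath>
#include <vector>

// numerically stable log-softmax: subtract the max logit before exponentiating
static std::vector<float> log_softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float v : logits) {
        sum += std::exp(v - max_logit);
    }
    std::vector<float> out(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        out[i] = logits[i] - max_logit - std::log(sum);
    }
    return out;
}
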
tools/server/server.cpp

Lines changed: 4 additions & 7 deletions
@@ -4704,22 +4704,19 @@ int main(int argc, char ** argv) {
                 return;
             }

-            llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
-
             // create and queue the task
             json responses = json::array();
             bool error = false;
             std::unordered_set<int> task_ids;
             {
                 std::vector<server_task> tasks;
-                auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
-                tasks.reserve(tokenized_docs.size());
-                for (size_t i = 0; i < tokenized_docs.size(); i++) {
-                    auto tmp = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
+                auto inputs = tokenize_rerank(ctx_server.model, query, documents);
+                tasks.reserve(documents.size());
+                for (size_t i = 0; i < inputs.size(); i++) {
                     server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                     task.id = ctx_server.queue_tasks.get_new_id();
                     task.index = i;
-                    task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                    task.prompt_tokens = server_tokens(inputs[i], ctx_server.mctx != nullptr);
                     tasks.push_back(std::move(task));
                 }

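For context, query and documents in this handler come straight from the JSON body of the rerank request, so the path above is exercised by a request of roughly the following shape. The endpoint path and any optional fields (such as top_n) are assumptions and not part of this diff; only the "query" and "documents" fields correspond to the handler variables shown above:

POST /v1/rerank
{
    "query": "what is panda?",
    "documents": [
        "The giant panda is a bear species endemic to China.",
        "Paris is the capital of France."
    ]
}
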
tools/server/utils.hpp

Lines changed: 34 additions & 30 deletions
@@ -260,43 +260,47 @@ static size_t validate_utf8(const std::string& text) {
 // template utils
 //

-// format rerank task:
+// format and tokenize rerank task:
 // - using SEP token: [BOS]query[EOS][SEP]doc[EOS]
 // - using prompt: <rerank_prefix>query<rerank_suffix>doc
-static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
+static std::vector<llama_tokens> tokenize_rerank(const struct llama_model * model, const std::string & query, const std::vector<std::string> & documents) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
-    llama_tokens result;
+    std::vector<llama_tokens> result;

-    if (llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL) {
-        // Get EOS token - use SEP token as fallback if EOS is not available
-        llama_token eos_token = llama_vocab_eos(vocab);
-        if (eos_token == LLAMA_TOKEN_NULL) {
-            eos_token = llama_vocab_sep(vocab);
-        }
+    for (const auto & doc : documents) {
+        if (llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL) {
+            // Get EOS token - use SEP token as fallback if EOS is not available
+            llama_tokens tok;
+            llama_tokens tok_query = common_tokenize(vocab, query, false, false);
+            llama_tokens tok_doc   = common_tokenize(vocab, doc, false, false);
+            llama_token eos_token = llama_vocab_eos(vocab);
+            if (eos_token == LLAMA_TOKEN_NULL) {
+                eos_token = llama_vocab_sep(vocab);
+            }

-        result.reserve(doc.size() + query.size() + 4);
-        result.push_back(llama_vocab_bos(vocab));
-        result.insert(result.end(), query.begin(), query.end());
-        result.push_back(eos_token);
-        result.push_back(llama_vocab_sep(vocab));
-        result.insert(result.end(), doc.begin(), doc.end());
-        result.push_back(eos_token);
-    } else {
-        // using prompt template
-        const char * prefix = llama_model_chat_template(model, "rerank_prefix");
-        const char * suffix = llama_model_chat_template(model, "rerank_suffix");
+            tok.reserve(doc.size() + query.size() + 4);
+            tok.push_back(llama_vocab_bos(vocab));
+            tok.insert(tok.end(), tok_query.begin(), tok_query.end());
+            tok.push_back(eos_token);
+            tok.push_back(llama_vocab_sep(vocab));
+            tok.insert(tok.end(), tok_doc.begin(), tok_doc.end());
+            tok.push_back(eos_token);

-        if (prefix == NULL && suffix == NULL) {
-            throw std::runtime_error("Rerank prompt template not found in the model\n");
-        }
+            result.push_back(std::move(tok));
+        } else {
+            // using prompt template
+            const char * tmpl = llama_model_chat_template(model, "rerank");
+            if (tmpl == nullptr) {
+                throw std::runtime_error("model does not have rerank template");
+            }

-        const llama_tokens prefix_tokens = prefix ? common_tokenize(vocab, prefix, true, false) : llama_tokens();
-        const llama_tokens suffix_tokens = suffix ? common_tokenize(vocab, suffix, false, false) : llama_tokens();
-        result.reserve(prefix_tokens.size() + query.size() + suffix_tokens.size() + doc.size());
-        result.insert(result.end(), prefix_tokens.begin(), prefix_tokens.end());
-        result.insert(result.end(), query.begin(), query.end());
-        result.insert(result.end(), suffix_tokens.begin(), suffix_tokens.end());
-        result.insert(result.end(), doc.begin(), doc.end());
+            std::string prompt = tmpl;
+            // TODO: may not be efficient to call string_replace_all twice
+            string_replace_all(prompt, "{query}", query);
+            string_replace_all(prompt, "{document}", doc);
+            llama_tokens tok = common_tokenize(vocab, prompt, true, false);
+            result.push_back(std::move(tok));
+        }
     }

     return result;

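A hypothetical usage sketch of the new helper, assuming model points at an already loaded reranking model (either a SEP-token style reranker or one converted with the "rerank" template above); the query and document strings are made up for illustration:

std::vector<std::string> documents = {
    "llama.cpp is a C/C++ LLM inference library.",
    "Bananas are rich in potassium.",
};
std::vector<llama_tokens> inputs = tokenize_rerank(model, "what is llama.cpp?", documents);
// one tokenized prompt per document, in the same order as documents:
// either [BOS]query[EOS][SEP]doc[EOS] or the "rerank" chat template with
// {query}/{document} filled in, depending on what the model provides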