diff --git a/examples/cross_encoder/applications/cross-encoder_instruct_reranking.py b/examples/cross_encoder/applications/cross-encoder_instruct_reranking.py
new file mode 100644
index 000000000..8488aef59
--- /dev/null
+++ b/examples/cross_encoder/applications/cross-encoder_instruct_reranking.py
@@ -0,0 +1,135 @@
+"""
+This example demonstrates how to use the CrossEncoder with instruction-following reranker models such as Qwen3-Reranker or BGE-Reranker.
+The new `prompt_template` and `prompt_template_kwargs` arguments in the `predict` and `rank` methods allow for
+flexible and dynamic formatting of the input for such models.
+
+This script covers three main scenarios:
+1. Ranking without any template (baseline).
+2. Ranking with a `prompt_template` provided at runtime.
+3. Ranking with a dynamic `instruction` passed via `prompt_template_kwargs`.
+
+Finally, it provides a guide on how to set a default prompt template in the model's `config.json`.
+"""
+
+from sentence_transformers.cross_encoder import CrossEncoder
+
+# We use a Qwen Reranker model here. In a real-world scenario, this could also be
+# an instruction-tuned model like 'BAAI/bge-reranker-large'.
+model = CrossEncoder("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", trust_remote_code=True)
+# The sequence classification head needs a padding token for batched (padded) inference; reuse the tokenizer's.
+model.model.config.pad_token_id = model.tokenizer.pad_token_id
+
+query = "What is the capital of China?"
+documents = [
+ "The capital of China is Beijing.",
+ "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+# First, we create the sentence pairs for the query and all documents
+sentence_pairs = [[query, doc] for doc in documents]
+
+print("--- 1. Reranking without any template (Incorrect Usage of Qwen3 Reranker) ---")
+# The model receives the plain query and document pairs.
+baseline_scores = model.predict(sentence_pairs, convert_to_numpy=True)
+scored_docs = sorted(zip(baseline_scores, documents), key=lambda x: x[0], reverse=True)
+
+print("Query:", query)
+for score, doc in scored_docs:
+ print(f"{score:.4f}\t{doc}")
+
+print("\n\n--- 2. Reranking with a runtime prompt_template ---")
+# The query and document are formatted using the template before being passed to the model.
+# This changes the input text and thus the resulting scores.
+prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
+instruction = "Given a web search query, retrieve relevant passages that answer the query"
+query_template = f"{prefix}: {instruction}\n: {{query}}\n"
+document_template = f": {{document}}{suffix}"
+
+template = query_template + document_template
+template_scores = model.predict(sentence_pairs, prompt_template=template)
+scored_docs_template = sorted(zip(template_scores, documents), key=lambda x: x[0], reverse=True)
+
+print("Using template:", template)
+print("Query:", query)
+for score, doc in scored_docs_template:
+ print(f"{score:.4f}\t{doc}")
+# The scores will differ from the baseline because the model now processes different input text.
+
+print("\n\n--- 3. Reranking with a dynamic instruction ---")
+# This is useful for models that expect a specific instruction.
+# The instruction can be changed at runtime.
+prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
+instruct_template = f"{prefix}: {{instruction}}\n: {{query}}\n: {{document}}{suffix}"
+instruct_kwargs_1 = {"instruction": "Given a query, find the most relevant document."}
+instruct_kwargs_2 = {"instruction": "Given a question, find the incorrect answer."} # Misleading instruction
+
+print(f"Using template: {instruct_template}")
+print(f"With instruction 1: '{instruct_kwargs_1['instruction']}'")
+instruction_scores_1 = model.predict(
+ sentence_pairs, prompt_template=instruct_template, prompt_template_kwargs=instruct_kwargs_1
+)
+scored_docs_instruct_1 = sorted(zip(instruction_scores_1, documents), key=lambda x: x[0], reverse=True)
+for score, doc in scored_docs_instruct_1:
+ print(f"{score:.4f}\t{doc}")
+
+print(f"\nWith instruction 2: '{instruct_kwargs_2['instruction']}'")
+instruction_scores_2 = model.predict(
+ sentence_pairs, prompt_template=instruct_template, prompt_template_kwargs=instruct_kwargs_2
+)
+scored_docs_instruct_2 = sorted(zip(instruction_scores_2, documents), key=lambda x: x[0], reverse=True)
+for score, doc in scored_docs_instruct_2:
+ print(f"{score:.4f}\t{doc}")
+# The scores for instruction 1 and 2 will likely differ, as the instruction text changes the input.
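+
+# The same `prompt_template` and `prompt_template_kwargs` arguments are also accepted by `rank`,
+# which builds the (query, document) pairs and sorts the results for you. A minimal sketch,
+# reusing the template and instruction from scenario 3:
+print("\n\n--- Reranking via `rank` with a prompt template ---")
+ranked = model.rank(query, documents, prompt_template=instruct_template, prompt_template_kwargs=instruct_kwargs_1)
+for entry in ranked:
+    print(f"{entry['score']:.4f}\t{documents[entry['corpus_id']]}")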
+
+# --- 4. Guide: Setting a Default Prompt Template in config.json ---
+#
+# If you are a model creator or want to use a specific prompt format consistently
+# without passing it in every `rank` or `predict` call, you can set a default
+# template in the model's `config.json` file.
+#
+# Step 1: Save your base model to a directory.
+#
+# from sentence_transformers import CrossEncoder
+# import json
+#
+# model = CrossEncoder("your-base-model-name")
+# save_path = "path/to/your-instruct-model"
+# model.save(save_path)
+#
+# Step 2: Modify the `config.json` in the saved directory.
+# Add the "prompt_template" and "prompt_template_kwargs" keys to the
+# "sentence_transformers" dictionary.
+#
+# // path/to/your-instruct-model/config.json
+# {
+# ...
+# "sentence_transformers": {
+# "version": "3.0.0.dev0",
+# "prompt_template": "Instruct: {instruction}\nQuery: {query}\nDocument: {document}",
+# "prompt_template_kwargs": {
+# "instruction": "Given a query, find the most relevant document."
+# }
+# },
+# ...
+# }
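+#
+# Instead of editing the file by hand, you can apply the same change programmatically. A minimal
+# sketch (assuming the model was already saved to `save_path` in Step 1):
+#
+# import json
+#
+# config_path = f"{save_path}/config.json"
+# with open(config_path) as f:
+#     config = json.load(f)
+# # Add the default template and its keyword arguments under the "sentence_transformers" key
+# config.setdefault("sentence_transformers", {})
+# config["sentence_transformers"]["prompt_template"] = "Instruct: {instruction}\nQuery: {query}\nDocument: {document}"
+# config["sentence_transformers"]["prompt_template_kwargs"] = {"instruction": "Given a query, find the most relevant document."}
+# with open(config_path, "w") as f:
+#     json.dump(config, f, indent=2)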
+#
+# Step 3: Load the model from the modified path.
+# It will now use the default template automatically.
+#
+# instruct_model = CrossEncoder(save_path, trust_remote_code=True)
+# sentence_pairs = [[query, doc] for doc in documents]
+# scores = instruct_model.predict(sentence_pairs)
+#
+# # This call is now equivalent to calling the original model with the full template arguments:
+# # original_model.predict(sentence_pairs,
+# # prompt_template="Instruct: {instruction}\nQuery: {query}\nDocument: {document}",
+# # prompt_template_kwargs={"instruction": "Given a query, find the most relevant document."})
+#
+# You can still override the default template by passing arguments at runtime:
+#
+# # This will use the new instruction, overriding the default one.
+# scores_new_instruction = instruct_model.predict(
+# sentence_pairs,
+# prompt_template_kwargs={"instruction": "Find the answer to the question."}
+# )
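+#
+# Note on invalid placeholders: if the template references a placeholder that is neither `query`,
+# `document`, nor a key in `prompt_template_kwargs`, `predict` and `rank` raise a KeyError that
+# lists the available placeholders. For example (with a hypothetical `missing_key` placeholder):
+#
+# model.predict(sentence_pairs, prompt_template="Instruct: {missing_key}\nQuery: {query}\nDocument: {document}")
+# # -> KeyError, listing the available placeholders (here: document, query)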
diff --git a/sentence_transformers/cross_encoder/CrossEncoder.py b/sentence_transformers/cross_encoder/CrossEncoder.py
index 07b7bfc93..f10de22d3 100644
--- a/sentence_transformers/cross_encoder/CrossEncoder.py
+++ b/sentence_transformers/cross_encoder/CrossEncoder.py
@@ -170,6 +170,12 @@ def __init__(
if config.architectures is not None:
classifier_trained = any([arch.endswith("ForSequenceClassification") for arch in config.architectures])
+ self.default_prompt_template: str | None = None
+ self.default_prompt_template_kwargs: dict[str, Any] | None = None
+ if hasattr(config, "sentence_transformers"):
+ self.default_prompt_template = config.sentence_transformers.get("prompt_template")
+ self.default_prompt_template_kwargs = config.sentence_transformers.get("prompt_template_kwargs")
+
if num_labels is None and not classifier_trained:
num_labels = 1
@@ -538,6 +544,8 @@ def predict(
apply_softmax: bool | None = ...,
convert_to_numpy: Literal[False] = ...,
convert_to_tensor: Literal[False] = ...,
+ prompt_template: str | None = None,
+ prompt_template_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor: ...
@overload
@@ -550,6 +558,8 @@ def predict(
apply_softmax: bool | None = ...,
convert_to_numpy: Literal[True] = True,
convert_to_tensor: Literal[False] = False,
+ prompt_template: str | None = None,
+ prompt_template_kwargs: dict[str, Any] | None = None,
) -> np.ndarray: ...
@overload
@@ -562,6 +572,8 @@ def predict(
apply_softmax: bool | None = ...,
convert_to_numpy: bool = ...,
convert_to_tensor: Literal[True] = ...,
+ prompt_template: str | None = None,
+ prompt_template_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor: ...
@overload
@@ -574,6 +586,8 @@ def predict(
apply_softmax: bool | None = ...,
convert_to_numpy: Literal[False] = ...,
convert_to_tensor: Literal[False] = ...,
+ prompt_template: str | None = None,
+ prompt_template_kwargs: dict[str, Any] | None = None,
) -> list[torch.Tensor]: ...
@torch.inference_mode()
@@ -587,6 +601,8 @@ def predict(
apply_softmax: bool | None = False,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
+ prompt_template: str | None = None,
+ prompt_template_kwargs: dict[str, Any] | None = None,
) -> list[torch.Tensor] | np.ndarray | torch.Tensor:
"""
Performs predictions with the CrossEncoder on the given sentence pairs.
@@ -599,13 +615,17 @@ def predict(
activation_fn (callable, optional): Activation function applied on the logits output of the CrossEncoder.
If None, the ``model.activation_fn`` will be used, which defaults to :class:`torch.nn.Sigmoid` if num_labels=1, else
:class:`torch.nn.Identity`. Defaults to None.
- convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
apply_softmax (bool, optional): If set to True and `model.num_labels > 1`, applies softmax on the logits
output such that for each sample, the scores of each class sum to 1. Defaults to False.
convert_to_numpy (bool, optional): Whether the output should be a list of numpy vectors. If False, output
a list of PyTorch tensors. Defaults to True.
convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
Defaults to False.
+ prompt_template (str, optional): A template used to format each input sentence pair. The template should contain
+ the placeholders `{query}` and `{document}`, for example: "Query: {query} Document: {document}". If None,
+ the default template from the model's `config.json` (if any) is used. Defaults to None.
+ prompt_template_kwargs (dict[str, Any], optional): A dictionary of keyword arguments used to format the prompt template.
+ For example, you can provide an instruction: `{"instruction": "Determine the relevance."}` for a template like
+ "Instruct: {instruction} Query: {query} Document: {document}". Defaults to None.
Returns:
Union[List[torch.Tensor], np.ndarray, torch.Tensor]: Predictions for the passed sentence pairs.
@@ -634,6 +654,28 @@ def predict(
logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
)
+ # If prompt_template is provided or default_prompt_template in config.json is set, use it to format the input sentence pairs
+ final_prompt_template = prompt_template if prompt_template is not None else self.default_prompt_template
+ if final_prompt_template:
+ final_kwargs = self.default_prompt_template_kwargs.copy() if self.default_prompt_template_kwargs else {}
+ if prompt_template_kwargs:
+ # Update final_kwargs with any additional keyword arguments provided in prompt_template_kwargs
+ final_kwargs.update(prompt_template_kwargs)
+
+ formatted_sentences = []
+ try:
+ for query, doc in sentences:
+ all_kwargs = {"query": query, "document": doc, **final_kwargs}
+ formatted_sentences.append(final_prompt_template.format(**all_kwargs))
+ except KeyError as e:
+ # If a placeholder in the prompt template is not valid, raise an error
+ available_keys = ["query", "document"] + list(final_kwargs.keys())
+ raise KeyError(
+ f"A placeholder in the prompt template is not valid. The placeholder {e} was not found. "
+ f"Available placeholders are: {', '.join(sorted(list(set(available_keys))))}."
+ ) from e
+ sentences = formatted_sentences
+
if activation_fn is not None:
self.set_activation_fn(activation_fn, set_default=False)
@@ -681,6 +723,8 @@ def rank(
apply_softmax=False,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
+ prompt_template: str | None = None,
+ prompt_template_kwargs: dict[str, Any] | None = None,
) -> list[dict[Literal["corpus_id", "score", "text"], int | float | str]]:
"""
Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
@@ -696,6 +740,11 @@ def rank(
convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
apply_softmax (bool, optional): If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output. Defaults to False.
convert_to_tensor (bool, optional): Convert the output to a tensor. Defaults to False.
+ prompt_template (str, optional): A template used to format each input sentence pair. The template should contain
+ the placeholders `{query}` and `{document}`, for example: "Query: {query} Document: {document}". If None,
+ the default template from the model's `config.json` (if any) is used. Defaults to None.
+ prompt_template_kwargs (dict[str, Any], optional): A dictionary of keyword arguments used to format the prompt template.
+ For example, you can provide an instruction: `{"instruction": "Determine the relevance."}` for a template like
+ "Instruct: {instruction} Query: {query} Document: {document}". Defaults to None.
Returns:
List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]: A sorted list with the "corpus_id", "score", and optionally "text" of the documents.
@@ -750,6 +799,8 @@ def rank(
apply_softmax=apply_softmax,
convert_to_numpy=convert_to_numpy,
convert_to_tensor=convert_to_tensor,
+ prompt_template=prompt_template,
+ prompt_template_kwargs=prompt_template_kwargs,
)
results = []
diff --git a/tests/cross_encoder/test_cross_encoder.py b/tests/cross_encoder/test_cross_encoder.py
index f8d5969c5..b7135f628 100644
--- a/tests/cross_encoder/test_cross_encoder.py
+++ b/tests/cross_encoder/test_cross_encoder.py
@@ -132,6 +132,81 @@ def test_predict_softmax():
assert not torch.isclose(scores.sum(1), torch.ones(len(corpus), device=scores.device)).all()
+def test_predict_with_prompt_template():
+ model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce")
+ query = "A man is eating pasta."
+ corpus = [
+ "A man is eating food.",
+ "A woman is playing violin.",
+ ]
+ pairs = [[query, doc] for doc in corpus]
+ prompt_template = (
+ "Instruct: Given a query and a document, determine if they are relevant. Query: {query} Document: {document}"
+ )
+
+ # 1. Test with prompt_template in predict
+ scores_prompt = model.predict(pairs, prompt_template=prompt_template)
+
+ # 2. Test without prompt_template
+ scores_no_prompt = model.predict(pairs)
+
+ # The scores should be different as the input to the model is different
+ assert not np.allclose(scores_prompt, scores_no_prompt)
+
+ # 3. Test with prompt_template in rank
+ ranks_prompt = model.rank(query, corpus, prompt_template=prompt_template)
+ ranks_no_prompt = model.rank(query, corpus)
+
+ assert ranks_prompt[0]["score"] != ranks_no_prompt[0]["score"]
+ assert ranks_prompt[1]["score"] != ranks_no_prompt[1]["score"]
+
+
+def test_predict_with_default_prompt_template(tmp_path: Path):
+ # 1. Create a base model and save it
+ model_name = "cross-encoder-testing/reranker-bert-tiny-gooaq-bce"
+ original_model = CrossEncoder(model_name)
+ save_path = tmp_path / "model_with_template"
+ original_model.save(str(save_path))
+
+ # 2. Modify the config.json to add a default prompt template
+ config_path = save_path / "config.json"
+ with open(config_path) as f:
+ config = json.load(f)
+
+ prompt_template = "Instruct: {instruction} Query: {query} Document: {document}"
+ prompt_kwargs = {"instruction": "Determine relevance."}
+ if "sentence_transformers" not in config:
+ config["sentence_transformers"] = {}
+ config["sentence_transformers"]["prompt_template"] = prompt_template
+ config["sentence_transformers"]["prompt_template_kwargs"] = prompt_kwargs
+
+ with open(config_path, "w") as f:
+ json.dump(config, f)
+
+ # 3. Load the model with the modified config
+ model_with_template = CrossEncoder(str(save_path))
+ assert model_with_template.default_prompt_template == prompt_template
+ assert model_with_template.default_prompt_template_kwargs == prompt_kwargs
+
+ # 4. Perform prediction and compare results
+ query = "A man is eating pasta."
+ doc = "A man is eating food."
+
+ # Prediction with the model that has a default template
+ scores_with_default_template = model_with_template.predict([[query, doc]])
+
+ # Prediction with the original model (no template)
+ scores_original = original_model.predict([[query, doc]])
+
+ # The scores should be different because one uses the template and the other doesn't.
+ assert not np.allclose(scores_with_default_template, scores_original)
+
+ # 5. Test that runtime arguments can overwrite the default template
+ runtime_template = "Query: {query} and Doc: {document}"
+ scores_runtime_template = model_with_template.predict([[query, doc]], prompt_template=runtime_template)
+ assert not np.allclose(scores_with_default_template, scores_runtime_template)
+
+
@pytest.mark.parametrize(
"model_name", ["cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "cross-encoder/nli-MiniLM2-L6-H768"]
)