diff --git a/examples/cross_encoder/applications/cross-encoder_instruct_reranking.py b/examples/cross_encoder/applications/cross-encoder_instruct_reranking.py
new file mode 100644
index 000000000..8488aef59
--- /dev/null
+++ b/examples/cross_encoder/applications/cross-encoder_instruct_reranking.py
@@ -0,0 +1,135 @@
+"""
+This example demonstrates how to use the CrossEncoder with instruction-tuned models like Qwen-reranker or BGE-reranker.
+The new `prompt_template` and `prompt_template_kwargs` arguments in the `predict` and `rank` methods allow for
+flexible and dynamic formatting of the input for such models.
+
+This script covers three main scenarios:
+1. Ranking without any template (baseline).
+2. Ranking with a `prompt_template` provided at runtime.
+3. Ranking with a dynamic `instruction` passed via `prompt_template_kwargs`.
+
+Finally, it provides a guide on how to set a default prompt template in the model's `config.json`.
+"""
+
+from sentence_transformers.cross_encoder import CrossEncoder
+
+# We use a Qwen Reranker model here. In a real-world scenario, this could also be
+# an instruction-tuned model like 'BAAI/bge-reranker-large'.
+model = CrossEncoder("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", trust_remote_code=True)
+model.model.config.pad_token_id = model.tokenizer.pad_token_id
+
+query = "What is the capital of China?"
+documents = [
+    "The capital of China is Beijing.",
+    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+]
+
+# First, we create the sentence pairs for the query and all documents
+sentence_pairs = [[query, doc] for doc in documents]
+
+print("--- 1. Reranking without any template (Incorrect Usage of Qwen3 Reranker) ---")
+# The model receives the plain query and document pairs.
+baseline_scores = model.predict(sentence_pairs, convert_to_numpy=True)
+scored_docs = sorted(zip(baseline_scores, documents), key=lambda x: x[0], reverse=True)
+
+print("Query:", query)
+for score, doc in scored_docs:
+    print(f"{score:.4f}\t{doc}")
+
+print("\n\n--- 2. Reranking with a runtime prompt_template ---")
+# The query and document are formatted using the template before being passed to the model.
+# This changes the input text and thus the resulting scores.
+prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+instruction = "Given a web search query, retrieve relevant passages that answer the query"
+query_template = f"{prefix}<Instruct>: {instruction}\n<Query>: {{query}}\n"
+document_template = f"<Document>: {{document}}{suffix}"
+
+template = query_template + document_template
+template_scores = model.predict(sentence_pairs, prompt_template=template)
+scored_docs_template = sorted(zip(template_scores, documents), key=lambda x: x[0], reverse=True)
+
+print("Using template:", template)
+print("Query:", query)
+for score, doc in scored_docs_template:
+    print(f"{score:.4f}\t{doc}")
+# The scores will be different from the baseline because the model processes a different text.
+
+print("\n\n--- 3. Reranking with a dynamic instruction ---")
+# This is useful for models that expect a specific instruction.
+# The instruction can be changed at runtime.
+prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+instruct_template = f"{prefix}<Instruct>: {{instruction}}\n<Query>: {{query}}\n<Document>: {{document}}{suffix}"
+instruct_kwargs_1 = {"instruction": "Given a query, find the most relevant document."}
+instruct_kwargs_2 = {"instruction": "Given a question, find the incorrect answer."}  # Misleading instruction
+
+print(f"Using template: {instruct_template}")
+print(f"With instruction 1: '{instruct_kwargs_1['instruction']}'")
+instruction_scores_1 = model.predict(
+    sentence_pairs, prompt_template=instruct_template, prompt_template_kwargs=instruct_kwargs_1
+)
+scored_docs_instruct_1 = sorted(zip(instruction_scores_1, documents), key=lambda x: x[0], reverse=True)
+for score, doc in scored_docs_instruct_1:
+    print(f"{score:.4f}\t{doc}")
+
+print(f"\nWith instruction 2: '{instruct_kwargs_2['instruction']}'")
+instruction_scores_2 = model.predict(
+    sentence_pairs, prompt_template=instruct_template, prompt_template_kwargs=instruct_kwargs_2
+)
+scored_docs_instruct_2 = sorted(zip(instruction_scores_2, documents), key=lambda x: x[0], reverse=True)
+for score, doc in scored_docs_instruct_2:
+    print(f"{score:.4f}\t{doc}")
+# The scores for instruction 1 and 2 will likely differ, as the instruction text changes the input.
+
+# --- 4. Guide: Setting a Default Prompt Template in config.json ---
+#
+# If you are a model creator or want to use a specific prompt format consistently
+# without passing it in every `rank` or `predict` call, you can set a default
+# template in the model's `config.json` file.
+#
+# Step 1: Save your base model to a directory.
+#
+# from sentence_transformers import CrossEncoder
+# import json
+#
+# model = CrossEncoder("your-base-model-name")
+# save_path = "path/to/your-instruct-model"
+# model.save(save_path)
+#
+# Step 2: Modify the `config.json` in the saved directory.
+# Add the "prompt_template" and "prompt_template_kwargs" keys to the
+# "sentence_transformers" dictionary.
+#
+# // path/to/your-instruct-model/config.json
+# {
+#     ...
+#     "sentence_transformers": {
+#         "version": "3.0.0.dev0",
+#         "prompt_template": "Instruct: {instruction}\nQuery: {query}\nDocument: {document}",
+#         "prompt_template_kwargs": {
+#             "instruction": "Given a query, find the most relevant document."
+#         }
+#     },
+#     ...
+# }
+#
+# Step 3: Load the model from the modified path.
+# It will now use the default template automatically.
+#
+# instruct_model = CrossEncoder(save_path, trust_remote_code=True)
+# sentence_pairs = [[query, doc] for doc in documents]
+# scores = instruct_model.predict(sentence_pairs)
+#
+# # This call is now equivalent to calling the original model with the full template arguments:
+# # original_model.predict(sentence_pairs,
+# #                        prompt_template="Instruct: {instruction}\nQuery: {query}\nDocument: {document}",
+# #                        prompt_template_kwargs={"instruction": "Given a query, find the most relevant document."})
+#
+# You can still override the default template by passing arguments at runtime:
+#
+# # This will use the new instruction, overriding the default one.
+# scores_new_instruction = instruct_model.predict(
+#     sentence_pairs,
+#     prompt_template_kwargs={"instruction": "Find the answer to the question."}
+# )
diff --git a/sentence_transformers/cross_encoder/CrossEncoder.py b/sentence_transformers/cross_encoder/CrossEncoder.py
index 07b7bfc93..f10de22d3 100644
--- a/sentence_transformers/cross_encoder/CrossEncoder.py
+++ b/sentence_transformers/cross_encoder/CrossEncoder.py
@@ -170,6 +170,12 @@ def __init__(
         if config.architectures is not None:
             classifier_trained = any([arch.endswith("ForSequenceClassification") for arch in config.architectures])
 
+        self.default_prompt_template: str | None = None
+        self.default_prompt_template_kwargs: dict[str, Any] | None = None
+        if hasattr(config, "sentence_transformers"):
+            self.default_prompt_template = config.sentence_transformers.get("prompt_template")
+            self.default_prompt_template_kwargs = config.sentence_transformers.get("prompt_template_kwargs")
+
         if num_labels is None and not classifier_trained:
             num_labels = 1
 
@@ -538,6 +544,8 @@ def predict(
         apply_softmax: bool | None = ...,
         convert_to_numpy: Literal[False] = ...,
         convert_to_tensor: Literal[False] = ...,
+        prompt_template: str | None = None,
+        prompt_template_kwargs: dict[str, Any] | None = None,
     ) -> torch.Tensor: ...
 
     @overload
@@ -550,6 +558,8 @@ def predict(
         apply_softmax: bool | None = ...,
         convert_to_numpy: Literal[True] = True,
         convert_to_tensor: Literal[False] = False,
+        prompt_template: str | None = None,
+        prompt_template_kwargs: dict[str, Any] | None = None,
     ) -> np.ndarray: ...
 
     @overload
@@ -562,6 +572,8 @@ def predict(
         apply_softmax: bool | None = ...,
         convert_to_numpy: bool = ...,
         convert_to_tensor: Literal[True] = ...,
+        prompt_template: str | None = None,
+        prompt_template_kwargs: dict[str, Any] | None = None,
     ) -> torch.Tensor: ...
 
     @overload
@@ -574,6 +586,8 @@ def predict(
         apply_softmax: bool | None = ...,
         convert_to_numpy: Literal[False] = ...,
         convert_to_tensor: Literal[False] = ...,
+        prompt_template: str | None = None,
+        prompt_template_kwargs: dict[str, Any] | None = None,
     ) -> list[torch.Tensor]: ...
 
     @torch.inference_mode()
@@ -587,6 +601,8 @@ def predict(
         apply_softmax: bool | None = False,
         convert_to_numpy: bool = True,
         convert_to_tensor: bool = False,
+        prompt_template: str | None = None,
+        prompt_template_kwargs: dict[str, Any] | None = None,
     ) -> list[torch.Tensor] | np.ndarray | torch.Tensor:
         """
         Performs predictions with the CrossEncoder on the given sentence pairs.
@@ -599,13 +615,17 @@ def predict(
             activation_fn (callable, optional): Activation function applied on the logits output of the CrossEncoder.
                 If None, the ``model.activation_fn`` will be used, which defaults to :class:`torch.nn.Sigmoid` if
                 num_labels=1, else :class:`torch.nn.Identity`. Defaults to None.
-            convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
             apply_softmax (bool, optional): If set to True and `model.num_labels > 1`, applies softmax on the logits
                 output such that for each sample, the scores of each class sum to 1. Defaults to False.
             convert_to_numpy (bool, optional): Whether the output should be a list of numpy vectors. If False, output
                 a list of PyTorch tensors. Defaults to True.
             convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites
                 `convert_to_numpy`. Defaults to False.
+            prompt_template (str, optional): A template to format the input sentence pairs. The template should have placeholders
+                for `{query}` and `{document}`. For example: "Query: {query} Document: {document}".
+            prompt_template_kwargs (dict[str, Any], optional): A dictionary of keyword arguments to format the prompt template.
+                For example, you can provide an instruction: `{"instruction": "Determine the relevance."}` for a template like
+                "Instruct: {instruction} Query: {query} Document: {document}".
 
         Returns:
             Union[List[torch.Tensor], np.ndarray, torch.Tensor]: Predictions for the passed sentence pairs.
@@ -634,6 +654,28 @@ def predict(
             logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
         )
 
+        # If prompt_template is provided or default_prompt_template in config.json is set, use it to format the input sentence pairs
+        final_prompt_template = prompt_template if prompt_template is not None else self.default_prompt_template
+        if final_prompt_template:
+            final_kwargs = self.default_prompt_template_kwargs.copy() if self.default_prompt_template_kwargs else {}
+            if prompt_template_kwargs:
+                # Update final_kwargs with any additional keyword arguments provided in prompt_template_kwargs
+                final_kwargs.update(prompt_template_kwargs)
+
+            formatted_sentences = []
+            try:
+                for query, doc in sentences:
+                    all_kwargs = {"query": query, "document": doc, **final_kwargs}
+                    formatted_sentences.append(final_prompt_template.format(**all_kwargs))
+            except KeyError as e:
+                # If a placeholder in the prompt template is not valid, raise an error
+                available_keys = ["query", "document"] + list(final_kwargs.keys())
+                raise KeyError(
+                    f"The placeholder {e} in the prompt template is not valid. "
+                    f"Available placeholders are: {', '.join(sorted(set(available_keys)))}."
+                ) from e
+            sentences = formatted_sentences
+
         if activation_fn is not None:
             self.set_activation_fn(activation_fn, set_default=False)
@@ -681,6 +723,8 @@ def rank(
         apply_softmax=False,
         convert_to_numpy: bool = True,
         convert_to_tensor: bool = False,
+        prompt_template: str | None = None,
+        prompt_template_kwargs: dict[str, Any] | None = None,
     ) -> list[dict[Literal["corpus_id", "score", "text"], int | float | str]]:
         """
         Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
@@ -696,6 +740,11 @@ def rank(
             convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
             apply_softmax (bool, optional): If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output. Defaults to False.
             convert_to_tensor (bool, optional): Convert the output to a tensor. Defaults to False.
+            prompt_template (str, optional): A template to format the input sentence pairs. The template should have placeholders
+                for `{query}` and `{document}`. For example: "Query: {query} Document: {document}".
+            prompt_template_kwargs (dict[str, Any], optional): A dictionary of keyword arguments to format the prompt template.
+                For example, you can provide an instruction: `{"instruction": "Determine the relevance."}` for a template like
+                "Instruct: {instruction} Query: {query} Document: {document}".
 
         Returns:
             List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]: A sorted list with the "corpus_id", "score", and optionally "text" of the documents.
@@ -750,6 +799,8 @@ def rank(
             apply_softmax=apply_softmax,
             convert_to_numpy=convert_to_numpy,
             convert_to_tensor=convert_to_tensor,
+            prompt_template=prompt_template,
+            prompt_template_kwargs=prompt_template_kwargs,
         )
 
         results = []
diff --git a/tests/cross_encoder/test_cross_encoder.py b/tests/cross_encoder/test_cross_encoder.py
index f8d5969c5..b7135f628 100644
--- a/tests/cross_encoder/test_cross_encoder.py
+++ b/tests/cross_encoder/test_cross_encoder.py
@@ -132,6 +132,81 @@ def test_predict_softmax():
     assert not torch.isclose(scores.sum(1), torch.ones(len(corpus), device=scores.device)).all()
 
 
+def test_predict_with_prompt_template():
+    model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce")
+    query = "A man is eating pasta."
+    corpus = [
+        "A man is eating food.",
+        "A woman is playing violin.",
+    ]
+    pairs = [[query, doc] for doc in corpus]
+    prompt_template = (
+        "Instruct: Given a query and a document, determine if they are relevant. Query: {query} Document: {document}"
+    )
+
+    # 1. Test with prompt_template in predict
+    scores_prompt = model.predict(pairs, prompt_template=prompt_template)
+
+    # 2. Test without prompt_template
+    scores_no_prompt = model.predict(pairs)
+
+    # The scores should be different as the input to the model is different
+    assert not np.allclose(scores_prompt, scores_no_prompt)
+
+    # 3. Test with prompt_template in rank
+    ranks_prompt = model.rank(query, corpus, prompt_template=prompt_template)
+    ranks_no_prompt = model.rank(query, corpus)
+
+    assert ranks_prompt[0]["score"] != ranks_no_prompt[0]["score"]
+    assert ranks_prompt[1]["score"] != ranks_no_prompt[1]["score"]
+
+
+def test_predict_with_default_prompt_template(tmp_path: Path):
+    # 1. Create a base model and save it
+    model_name = "cross-encoder-testing/reranker-bert-tiny-gooaq-bce"
+    original_model = CrossEncoder(model_name)
+    save_path = tmp_path / "model_with_template"
+    original_model.save(str(save_path))
+
+    # 2. Modify the config.json to add a default prompt template
+    config_path = save_path / "config.json"
+    with open(config_path) as f:
+        config = json.load(f)
+
+    prompt_template = "Instruct: {instruction} Query: {query} Document: {document}"
+    prompt_kwargs = {"instruction": "Determine relevance."}
+    if "sentence_transformers" not in config:
+        config["sentence_transformers"] = {}
+    config["sentence_transformers"]["prompt_template"] = prompt_template
+    config["sentence_transformers"]["prompt_template_kwargs"] = prompt_kwargs
+
+    with open(config_path, "w") as f:
+        json.dump(config, f)
+
+    # 3. Load the model with the modified config
+    model_with_template = CrossEncoder(str(save_path))
+    assert model_with_template.default_prompt_template == prompt_template
+    assert model_with_template.default_prompt_template_kwargs == prompt_kwargs
+
+    # 4. Perform prediction and compare results
+    query = "A man is eating pasta."
+    doc = "A man is eating food."
+
+    # Prediction with the model that has a default template
+    scores_with_default_template = model_with_template.predict([[query, doc]])
+
+    # Prediction with the original model (no template)
+    scores_original = original_model.predict([[query, doc]])
+
+    # The scores should be different because one uses the template and the other doesn't.
+    assert not np.allclose(scores_with_default_template, scores_original)
+
+    # 5. Test that runtime arguments can overwrite the default template
+    runtime_template = "Query: {query} and Doc: {document}"
+    scores_runtime_template = model_with_template.predict([[query, doc]], prompt_template=runtime_template)
+    assert not np.allclose(scores_with_default_template, scores_runtime_template)
+
+
 @pytest.mark.parametrize(
     "model_name", ["cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "cross-encoder/nli-MiniLM2-L6-H768"]
 )