diff --git a/libs/community/langchain_community/cross_encoders/__init__.py b/libs/community/langchain_community/cross_encoders/__init__.py index 65d304ee1..fbfd1da57 100644 --- a/libs/community/langchain_community/cross_encoders/__init__.py +++ b/libs/community/langchain_community/cross_encoders/__init__.py @@ -23,6 +23,9 @@ from langchain_community.cross_encoders.huggingface import ( HuggingFaceCrossEncoder, ) + from langchain_community.cross_encoders.mixedbreadai import ( + MixedbreadAICrossEncoder, + ) from langchain_community.cross_encoders.sagemaker_endpoint import ( SagemakerEndpointCrossEncoder, ) @@ -32,6 +35,7 @@ "FakeCrossEncoder", "HuggingFaceCrossEncoder", "SagemakerEndpointCrossEncoder", + "MixedbreadAICrossEncoder", ] _module_lookup = { @@ -39,6 +43,7 @@ "FakeCrossEncoder": "langchain_community.cross_encoders.fake", "HuggingFaceCrossEncoder": "langchain_community.cross_encoders.huggingface", "SagemakerEndpointCrossEncoder": "langchain_community.cross_encoders.sagemaker_endpoint", # noqa: E501 + "MixedbreadAICrossEncoder": "langchain_community.cross_encoders.mixedbreadai", } diff --git a/libs/community/langchain_community/cross_encoders/mixedbreadai.py b/libs/community/langchain_community/cross_encoders/mixedbreadai.py new file mode 100644 index 000000000..389fc4f15 --- /dev/null +++ b/libs/community/langchain_community/cross_encoders/mixedbreadai.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, List, Tuple + +from pydantic import BaseModel, ConfigDict, Field + +from collections import defaultdict + +from langchain_community.cross_encoders.base import BaseCrossEncoder + +DEFAULT_MODEL_NAME = "mixedbread-ai/mxbai-rerank-base-v2" + + +class MixedbreadAICrossEncoder(BaseModel, BaseCrossEncoder): + """Mixedbread cross encoder models. + Args: + model_name: The name or identifier of the Mixedbread AI model to use. + model_kwargs: Additional keyword arguments to pass to the model. + normalize_scores: Whether to normalize the scores returned by the model. + Defaults to True. + + Example: + .. code-block:: python + from langchain_community.cross_encoders import MixedbreadAICrossEncoder + model_name = "mixedbread-ai/mxbai-rerank-base-v2" + model_kwargs = {'top_k': 10} + mb = MixedbreadAICrossEncoder( + model_name=model_name, + model_kwargs=model_kwargs + ) + """ + + client: Any = None #: :meta private: + model_name: str = DEFAULT_MODEL_NAME + """Model name to use.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" + normalize_scores: bool = Field(default=True) + """Whether to normalize scores to [0, 1] range.""" + + def __init__(self, **kwargs: Any): + """Initialize the Mixedbread reranker.""" + super().__init__(**kwargs) + try: + from mxbai_rerank import MxbaiRerankV2 + except ImportError as exc: + raise ImportError( + "Could not import mxbai_rerank python package. " + "Please install it with `pip install mxbai-rerank`." + ) from exc + + self.client = MxbaiRerankV2(self.model_name, **self.model_kwargs) + + model_config = ConfigDict(extra="forbid", protected_namespaces=()) + + def _normalize_scores(self, scores: List[float]) -> List[float]: + """Normalise scores to [0, 1] range.""" + if not scores: + return scores + + min_score = min(scores) + max_score = max(scores) + + if max_score == min_score: + return [0.0] * len(scores) + + return [(score - min_score) / (max_score - min_score) for score in scores] + + def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: + """Compute similarity scores using a Mixedbread transformer model. + + Args: + text_pairs: The list of text pairs to score the similarity. + Each tuple should be (query, document). + + Returns: + List of similarity/relevance scores between query-document pairs, + one float score for each input pair. + """ + if not text_pairs: + return [] + + query_groups = defaultdict(list) + for i, (query, doc) in enumerate(text_pairs): + query_groups[query].append((i, doc)) + + + # Process each query group + scores = [0.0] * len(text_pairs) + for query, doc_entries in query_groups.items(): + indices = [i for i, _ in doc_entries] + documents = [doc for _, doc in doc_entries] + + results = self.client.rank( + query, + documents, + return_documents=False, + top_k=len(documents) + ) + + # Map scores back to original positions + for res_idx, result in enumerate(results): + orig_idx = indices[result.index] + scores[orig_idx] = result.score + + # Normalize scores if requested + if self.normalize_scores: + scores = self._normalize_scores(scores) + + return scores diff --git a/libs/community/tests/integration_tests/cross_encoders/test_mixedbreadai.py b/libs/community/tests/integration_tests/cross_encoders/test_mixedbreadai.py new file mode 100644 index 000000000..32b496eba --- /dev/null +++ b/libs/community/tests/integration_tests/cross_encoders/test_mixedbreadai.py @@ -0,0 +1,49 @@ +"""Test mixedbreadai cross encoders.""" + +from langchain_community.cross_encoders import MixedbreadAICrossEncoder + + +def _assert(encoder: MixedbreadAICrossEncoder) -> None: + query = "I love you" + texts = ["I love you", "I like you", "I don't like you", "I hate you"] + output = encoder.score([(query, text) for text in texts]) + + # Check that we got scores for all texts + assert len(output) == len(texts) + + # Check that scores are in descending order (most relevant first) + for i in range(len(texts) - 1): + assert output[i] > output[i + 1], f"Score at index {i} ({output[i]}) should be greater than score at index {i+1} ({output[i+1]})" + + +def test_mixedbreadai_cross_encoder() -> None: + """Test MixedbreadAICrossEncoder with default model.""" + encoder = MixedbreadAICrossEncoder() + _assert(encoder) + + +def test_mixedbreadai_cross_encoder_with_designated_model_name() -> None: + """Test MixedbreadAICrossEncoder with specific model.""" + encoder = MixedbreadAICrossEncoder(model_name="mixedbread-ai/mxbai-rerank-base-v1") + _assert(encoder) + + +def test_mixedbreadai_cross_encoder_without_normalization() -> None: + """Test MixedbreadAICrossEncoder with normalization disabled.""" + encoder = MixedbreadAICrossEncoder(normalize_scores=False) + _assert(encoder) + + + +def test_mixedbreadai_cross_encoder_multilingual() -> None: + """Test MixedbreadAICrossEncoder with multilingual content.""" + encoder = MixedbreadAICrossEncoder() + query = "¿Cómo afecta la agricultura al clima?" + texts = [ + "El cambio climático provoca sequías e inundaciones que afectan los cultivos.", + "Climate change leads to droughts and floods, affecting crop yields.", + "Agriculture is impacted by rising temperatures and unpredictable weather." + ] + output = encoder.score([(query, text) for text in texts]) + assert len(output) == len(texts) + diff --git a/libs/community/tests/unit_tests/cross_encoders/test_imports.py b/libs/community/tests/unit_tests/cross_encoders/test_imports.py index 5de7395f1..2e680f3c4 100644 --- a/libs/community/tests/unit_tests/cross_encoders/test_imports.py +++ b/libs/community/tests/unit_tests/cross_encoders/test_imports.py @@ -5,6 +5,7 @@ "FakeCrossEncoder", "HuggingFaceCrossEncoder", "SagemakerEndpointCrossEncoder", + "MixedbreadAICrossEncoder", ]