Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions libs/community/langchain_community/cross_encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from langchain_community.cross_encoders.huggingface import (
HuggingFaceCrossEncoder,
)
from langchain_community.cross_encoders.mixedbreadai import (
MixedbreadAICrossEncoder,
)
from langchain_community.cross_encoders.sagemaker_endpoint import (
SagemakerEndpointCrossEncoder,
)
Expand All @@ -32,13 +35,15 @@
"FakeCrossEncoder",
"HuggingFaceCrossEncoder",
"SagemakerEndpointCrossEncoder",
"MixedbreadAICrossEncoder",
]

_module_lookup = {
"BaseCrossEncoder": "langchain_community.cross_encoders.base",
"FakeCrossEncoder": "langchain_community.cross_encoders.fake",
"HuggingFaceCrossEncoder": "langchain_community.cross_encoders.huggingface",
"SagemakerEndpointCrossEncoder": "langchain_community.cross_encoders.sagemaker_endpoint", # noqa: E501
"MixedbreadAICrossEncoder": "langchain_community.cross_encoders.mixedbreadai",
}


Expand Down
108 changes: 108 additions & 0 deletions libs/community/langchain_community/cross_encoders/mixedbreadai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Any, Dict, List, Tuple

from pydantic import BaseModel, ConfigDict, Field

from collections import defaultdict

from langchain_community.cross_encoders.base import BaseCrossEncoder

Check failure on line 7 in libs/community/langchain_community/cross_encoders/mixedbreadai.py

View workflow job for this annotation

GitHub Actions / cd libs/community / Python 3.11

Ruff (I001)

langchain_community/cross_encoders/mixedbreadai.py:1:1: I001 Import block is un-sorted or un-formatted

DEFAULT_MODEL_NAME = "mixedbread-ai/mxbai-rerank-base-v2"


class MixedbreadAICrossEncoder(BaseModel, BaseCrossEncoder):
"""Mixedbread cross encoder models.
Args:
model_name: The name or identifier of the Mixedbread AI model to use.
model_kwargs: Additional keyword arguments to pass to the model.
normalize_scores: Whether to normalize the scores returned by the model.
Defaults to True.
Example:
.. code-block:: python
from langchain_community.cross_encoders import MixedbreadAICrossEncoder
model_name = "mixedbread-ai/mxbai-rerank-base-v2"
model_kwargs = {'top_k': 10}
mb = MixedbreadAICrossEncoder(
model_name=model_name,
model_kwargs=model_kwargs
)
"""

client: Any = None #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
normalize_scores: bool = Field(default=True)
"""Whether to normalize scores to [0, 1] range."""

def __init__(self, **kwargs: Any):
"""Initialize the Mixedbread reranker."""
super().__init__(**kwargs)
try:
from mxbai_rerank import MxbaiRerankV2
except ImportError as exc:
raise ImportError(
"Could not import mxbai_rerank python package. "
"Please install it with `pip install mxbai-rerank`."
) from exc

self.client = MxbaiRerankV2(self.model_name, **self.model_kwargs)

model_config = ConfigDict(extra="forbid", protected_namespaces=())

def _normalize_scores(self, scores: List[float]) -> List[float]:
"""Normalise scores to [0, 1] range."""
if not scores:
return scores

min_score = min(scores)
max_score = max(scores)

if max_score == min_score:
return [0.0] * len(scores)

return [(score - min_score) / (max_score - min_score) for score in scores]

def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
"""Compute similarity scores using a Mixedbread transformer model.
Args:
text_pairs: The list of text pairs to score the similarity.
Each tuple should be (query, document).
Returns:
List of similarity/relevance scores between query-document pairs,
one float score for each input pair.
"""
if not text_pairs:
return []

query_groups = defaultdict(list)
for i, (query, doc) in enumerate(text_pairs):
query_groups[query].append((i, doc))


# Process each query group
scores = [0.0] * len(text_pairs)
for query, doc_entries in query_groups.items():
indices = [i for i, _ in doc_entries]
documents = [doc for _, doc in doc_entries]

results = self.client.rank(
query,
documents,
return_documents=False,
top_k=len(documents)
)

# Map scores back to original positions
for res_idx, result in enumerate(results):
orig_idx = indices[result.index]
scores[orig_idx] = result.score

# Normalize scores if requested
if self.normalize_scores:
scores = self._normalize_scores(scores)

return scores
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Test mixedbreadai cross encoders."""

from langchain_community.cross_encoders import MixedbreadAICrossEncoder


def _assert(encoder: MixedbreadAICrossEncoder) -> None:
query = "I love you"
texts = ["I love you", "I like you", "I don't like you", "I hate you"]
output = encoder.score([(query, text) for text in texts])

# Check that we got scores for all texts
assert len(output) == len(texts)

# Check that scores are in descending order (most relevant first)
for i in range(len(texts) - 1):
assert output[i] > output[i + 1], f"Score at index {i} ({output[i]}) should be greater than score at index {i+1} ({output[i+1]})"


def test_mixedbreadai_cross_encoder() -> None:
"""Test MixedbreadAICrossEncoder with default model."""
encoder = MixedbreadAICrossEncoder()
_assert(encoder)


def test_mixedbreadai_cross_encoder_with_designated_model_name() -> None:
"""Test MixedbreadAICrossEncoder with specific model."""
encoder = MixedbreadAICrossEncoder(model_name="mixedbread-ai/mxbai-rerank-base-v1")
_assert(encoder)


def test_mixedbreadai_cross_encoder_without_normalization() -> None:
"""Test MixedbreadAICrossEncoder with normalization disabled."""
encoder = MixedbreadAICrossEncoder(normalize_scores=False)
_assert(encoder)



def test_mixedbreadai_cross_encoder_multilingual() -> None:
"""Test MixedbreadAICrossEncoder with multilingual content."""
encoder = MixedbreadAICrossEncoder()
query = "¿Cómo afecta la agricultura al clima?"
texts = [
"El cambio climático provoca sequías e inundaciones que afectan los cultivos.",
"Climate change leads to droughts and floods, affecting crop yields.",
"Agriculture is impacted by rising temperatures and unpredictable weather."
]
output = encoder.score([(query, text) for text in texts])
assert len(output) == len(texts)

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"FakeCrossEncoder",
"HuggingFaceCrossEncoder",
"SagemakerEndpointCrossEncoder",
"MixedbreadAICrossEncoder",
]


Expand Down
Loading