Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions libs/community/langchain_community/cross_encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from langchain_community.cross_encoders.huggingface import (
HuggingFaceCrossEncoder,
)
from langchain_community.cross_encoders.mixedbreadai import (
MixedbreadAICrossEncoder,
)
from langchain_community.cross_encoders.sagemaker_endpoint import (
SagemakerEndpointCrossEncoder,
)
Expand All @@ -32,13 +35,15 @@
"FakeCrossEncoder",
"HuggingFaceCrossEncoder",
"SagemakerEndpointCrossEncoder",
"MixedbreadAICrossEncoder",
]

# Maps each public cross-encoder class name to the dotted path of the module
# that defines it (the same modules imported from at the top of this file).
# NOTE(review): presumably consumed by a lazy attribute lookup further down
# the file, which is not visible here — confirm before relying on it.
_module_lookup = {
"BaseCrossEncoder": "langchain_community.cross_encoders.base",
"FakeCrossEncoder": "langchain_community.cross_encoders.fake",
"HuggingFaceCrossEncoder": "langchain_community.cross_encoders.huggingface",
"SagemakerEndpointCrossEncoder": "langchain_community.cross_encoders.sagemaker_endpoint", # noqa: E501
"MixedbreadAICrossEncoder": "langchain_community.cross_encoders.mixedbreadai",
}


Expand Down
95 changes: 95 additions & 0 deletions libs/community/langchain_community/cross_encoders/mixedbreadai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Any, Dict, List, Tuple

from pydantic import BaseModel, ConfigDict, Field

from langchain_community.cross_encoders.base import BaseCrossEncoder

# Model identifier used as the default value of
# ``MixedbreadAICrossEncoder.model_name``.
DEFAULT_MODEL_NAME = "mixedbread-ai/mxbai-rerank-base-v2"


class MixedbreadAICrossEncoder(BaseModel, BaseCrossEncoder):
    """Mixedbread cross encoder models.

    Wraps ``mxbai_rerank.MxbaiRerankV2`` behind the ``BaseCrossEncoder``
    interface. ``model_kwargs`` are forwarded unchanged to the
    ``MxbaiRerankV2`` constructor; they are not passed to ``rank``.

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders import MixedbreadAICrossEncoder

            mb = MixedbreadAICrossEncoder(
                model_name="mixedbread-ai/mxbai-rerank-base-v2",
                normalize_scores=False,
            )
            scores = mb.score([("query", "first doc"), ("query", "second doc")])
    """

    client: Any = None  #: :meta private:
    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    normalize_scores: bool = Field(default=True)
    """Whether to normalize scores to [0, 1] range."""

    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    def __init__(self, **kwargs: Any):
        """Initialize the Mixedbread reranker.

        Args:
            **kwargs: Field values (``model_name``, ``model_kwargs``,
                ``normalize_scores``); unknown names are rejected because
                the model config forbids extras.

        Raises:
            ImportError: If the ``mxbai_rerank`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            from mxbai_rerank import MxbaiRerankV2
        except ImportError as exc:
            raise ImportError(
                "Could not import mxbai_rerank python package. "
                "Please install it with `pip install mxbai-rerank`."
            ) from exc

        # NOTE(review): the client is always MxbaiRerankV2, whatever
        # ``model_name`` is — confirm that v1 model names load correctly
        # through this class.
        self.client = MxbaiRerankV2(self.model_name, **self.model_kwargs)

    def _normalize_scores(self, scores: List[float]) -> List[float]:
        """Min-max normalize scores to the [0, 1] range.

        Args:
            scores: Raw scores to rescale.

        Returns:
            The rescaled scores. An empty input is returned as is, and when
            every score is identical (including a single score) each entry
            becomes ``0.0`` because there is no spread to scale by.
        """
        if not scores:
            return scores

        min_score = min(scores)
        max_score = max(scores)

        # Guard against division by zero when all scores are equal.
        if max_score == min_score:
            return [0.0] * len(scores)

        spread = max_score - min_score
        return [(score - min_score) / spread for score in scores]

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Compute similarity scores using a Mixedbread reranker model.

        Pairs are grouped by their query and each group is ranked with one
        ``rank`` call, so every document is scored against its own query
        even when ``text_pairs`` mixes several queries.

        Args:
            text_pairs: The list of text pairs to score the similarity.
                Each tuple should be (query, document).

        Returns:
            List of scores, one for each pair, in the order of
            ``text_pairs``. When ``normalize_scores`` is set, the scores
            are min-max normalized over the whole returned list.
        """
        if not text_pairs:
            return []

        # Remember, per distinct query, which positions of ``text_pairs``
        # belong to it (dicts keep insertion order, so calls are issued in
        # first-seen query order).
        positions_by_query: Dict[str, List[int]] = {}
        for position, (query, _document) in enumerate(text_pairs):
            positions_by_query.setdefault(query, []).append(position)

        # A document missing from the ranking results keeps the 0.0 default.
        scores: List[float] = [0.0] * len(text_pairs)
        for query, positions in positions_by_query.items():
            documents = [text_pairs[position][1] for position in positions]
            results = self.client.rank(
                query,
                documents,
                return_documents=False,
                # Request as many results as documents were submitted.
                top_k=len(documents),
            )
            # ``res.index`` refers to ``documents``; map it back to the
            # pair's position in ``text_pairs``.
            for res in results:
                scores[positions[res.index]] = float(res.score)

        if self.normalize_scores:
            scores = self._normalize_scores(scores)

        return scores
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Test mixedbreadai cross encoders."""

from langchain_community.cross_encoders import MixedbreadAICrossEncoder


def _assert(encoder: MixedbreadAICrossEncoder) -> None:
    """Score four sentences against "I love you" and check their ranking.

    The candidate sentences are listed from most to least similar to the
    query, so the returned scores are expected to be strictly decreasing.
    """
    query = "I love you"
    texts = ["I love you", "I like you", "I don't like you", "I hate you"]
    pairs = [(query, text) for text in texts]
    output = encoder.score(pairs)

    # One score must come back per candidate sentence.
    assert len(output) == len(texts)

    # Each score must be strictly greater than the one that follows it.
    for i, (current, following) in enumerate(zip(output, output[1:])):
        assert current > following, (
            f"Score at index {i} ({current}) should be greater than "
            f"score at index {i + 1} ({following})"
        )


def test_mixedbreadai_cross_encoder() -> None:
    """Run the shared ranking check on an encoder built with all defaults."""
    _assert(MixedbreadAICrossEncoder())


def test_mixedbreadai_cross_encoder_with_designated_model_name() -> None:
    """Run the shared ranking check with an explicitly chosen model name."""
    chosen_model = "mixedbread-ai/mxbai-rerank-base-v1"
    _assert(MixedbreadAICrossEncoder(model_name=chosen_model))


def test_mixedbreadai_cross_encoder_without_normalization() -> None:
    """Run the shared ranking check with score normalization switched off."""
    _assert(MixedbreadAICrossEncoder(normalize_scores=False))


def test_mixedbreadai_cross_encoder_with_model_kwargs() -> None:
    """Run the shared ranking check with every constructor field given."""
    settings = {
        "model_name": "mixedbread-ai/mxbai-rerank-large-v2",
        # mixedbreadai doesn't use traditional model_kwargs like HF
        "model_kwargs": {},
        "normalize_scores": True,
    }
    _assert(MixedbreadAICrossEncoder(**settings))


def test_mixedbreadai_cross_encoder_empty_input() -> None:
    """Scoring an empty list of pairs must return an empty list."""
    assert MixedbreadAICrossEncoder().score([]) == []


def test_mixedbreadai_cross_encoder_single_text() -> None:
    """Scoring exactly one pair must yield exactly one float."""
    query = "Hello world"
    texts = ["Hello world"]
    pairs = [(query, text) for text in texts]

    output = MixedbreadAICrossEncoder().score(pairs)

    assert len(output) == 1
    only_score = output[0]
    assert isinstance(only_score, float)


def test_mixedbreadai_cross_encoder_multilingual() -> None:
    """Test MixedbreadAICrossEncoder with multilingual content.

    A Spanish query is scored against three on-topic documents (one
    Spanish, two English) and one off-topic document; the off-topic
    document at index 3 must score lowest in both comparisons.
    """
    encoder = MixedbreadAICrossEncoder()
    query = "¿Cómo afecta la agricultura al clima?"
    texts = [
        "El cambio climático provoca sequías e inundaciones que afectan los cultivos.",
        "Climate change leads to droughts and floods, affecting crop yields.",
        "Agriculture is impacted by rising temperatures and unpredecable weather.",
        # Off-topic document: the assertions below compare against index 3,
        # which did not exist when only three texts were listed.
        "A data scientist builds machine learning models from large datasets.",
    ]
    output = encoder.score([(query, text) for text in texts])

    assert len(output) == len(texts)
    # Spanish text should rank higher than the unrelated text for a Spanish query
    assert output[0] > output[3]  # Spanish agriculture text > irrelevant text
    # Agriculture-related content should rank higher than unrelated content
    assert output[2] > output[3]  # Agriculture text > Data scientist text