generated from langchain-ai/integration-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 269
feat:(integration): add MixedbreadAICrossEncoder for reranking documents #274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
adityaiiitmk
wants to merge
14
commits into
langchain-ai:main
Choose a base branch
from
adityaiiitmk:aditya/mixedbreadai
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+163
−0
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
4fa5765
feat:(integration): add MixedbreadAICrossEncoder for reranking docum…
adityaiiitmk 2a8a0eb
test(integration): add tests for MixedbreadAICrossEncoder functionality
adityaiiitmk 89a69ac
Update libs/community/langchain_community/cross_encoders/mixedbreadai.py
adityaiiitmk 91bf9c4
Update libs/community/langchain_community/cross_encoders/mixedbreadai.py
adityaiiitmk a72c34b
Update libs/community/langchain_community/cross_encoders/mixedbreadai.py
adityaiiitmk c447cc7
Update libs/community/tests/integration_tests/cross_encoders/test_mix…
adityaiiitmk 9c49558
fix:Update test_mixedbreadai.py
adityaiiitmk 18f0fd1
fix:Update test_mixedbreadai.py
adityaiiitmk b9d6e07
fix : Update mixedbreadai.py based on code review
adityaiiitmk 40cbdba
fix: Correct docstring and import for MixedbreadAICrossEncoder
adityaiiitmk 7bac77a
Merge branch 'main' into aditya/mixedbreadai
adityaiiitmk 7b71897
Merge branch 'main' into aditya/mixedbreadai
adityaiiitmk 114af4d
Merge branch 'main' into aditya/mixedbreadai
mdrxy 029450c
Merge branch 'main' into aditya/mixedbreadai
mdrxy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
108 changes: 108 additions & 0 deletions
108
libs/community/langchain_community/cross_encoders/mixedbreadai.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| from typing import Any, Dict, List, Tuple | ||
|
|
||
| from pydantic import BaseModel, ConfigDict, Field | ||
|
|
||
| from collections import defaultdict | ||
|
|
||
| from langchain_community.cross_encoders.base import BaseCrossEncoder | ||
|
|
||
| DEFAULT_MODEL_NAME = "mixedbread-ai/mxbai-rerank-base-v2" | ||
|
|
||
|
|
||
class MixedbreadAICrossEncoder(BaseModel, BaseCrossEncoder):
    """Mixedbread cross encoder models.

    Loads a reranker through ``mxbai_rerank.MxbaiRerankV2`` and exposes it
    through the ``BaseCrossEncoder.score`` interface.

    Args:
        model_name: The name or identifier of the Mixedbread AI model to use.
        model_kwargs: Additional keyword arguments forwarded to the
            ``MxbaiRerankV2`` constructor. They are not passed to ``rank``;
            ``score`` always sets ``top_k`` itself.
        normalize_scores: Whether to min-max normalize the scores returned by
            the model. Defaults to True.

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders import MixedbreadAICrossEncoder

            model_name = "mixedbread-ai/mxbai-rerank-base-v2"
            model_kwargs = {"device": "cpu"}
            mb = MixedbreadAICrossEncoder(
                model_name=model_name,
                model_kwargs=model_kwargs,
            )
    """

    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    client: Any = None  #: :meta private:
    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    normalize_scores: bool = Field(default=True)
    """Whether to normalize scores to [0, 1] range."""

    def __init__(self, **kwargs: Any):
        """Initialize the Mixedbread reranker.

        Validates the fields, then builds the underlying client from
        ``model_name`` and ``model_kwargs``.

        Raises:
            ImportError: If the ``mxbai_rerank`` package cannot be imported.
        """
        super().__init__(**kwargs)
        try:
            from mxbai_rerank import MxbaiRerankV2
        except ImportError as exc:
            raise ImportError(
                "Could not import mxbai_rerank python package. "
                "Please install it with `pip install mxbai-rerank`."
            ) from exc

        self.client = MxbaiRerankV2(self.model_name, **self.model_kwargs)

    def _normalize_scores(self, scores: List[float]) -> List[float]:
        """Min-max normalise scores to the [0, 1] range.

        Args:
            scores: Raw scores to rescale.

        Returns:
            The rescaled scores, in the same order. An empty list is returned
            unchanged. When every score is identical (including a single
            score) the result is all zeros.
        """
        if not scores:
            return scores

        min_score = min(scores)
        max_score = max(scores)

        # NOTE(review): a constant input (e.g. one pair) collapses to 0.0;
        # confirm that is the intended value for the degenerate case.
        if max_score == min_score:
            return [0.0] * len(scores)

        span = max_score - min_score
        return [(score - min_score) / span for score in scores]

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Compute similarity scores using a Mixedbread transformer model.

        Pairs sharing the same query are ranked together in one ``rank`` call,
        and each score is written back to the position of its input pair.

        Args:
            text_pairs: The list of text pairs to score the similarity.
                Each tuple should be (query, document).

        Returns:
            List of similarity/relevance scores between query-document pairs,
            one float score for each input pair. When ``normalize_scores`` is
            set, min-max normalization is applied once over the whole list,
            i.e. across all queries in the call.
        """
        if not text_pairs:
            return []

        # Bucket documents by query, remembering where each pair came from.
        query_groups: Dict[str, List[Tuple[int, str]]] = defaultdict(list)
        for position, (query, doc) in enumerate(text_pairs):
            query_groups[query].append((position, doc))

        scores = [0.0] * len(text_pairs)
        for query, doc_entries in query_groups.items():
            indices = [position for position, _ in doc_entries]
            documents = [doc for _, doc in doc_entries]

            # Request as many results as there are documents in the group.
            results = self.client.rank(
                query,
                documents,
                return_documents=False,
                top_k=len(documents),
            )

            # ``result.index`` is used as the position within ``documents``;
            # translate it back to the position in ``text_pairs``.
            for result in results:
                scores[indices[result.index]] = float(result.score)

        if self.normalize_scores:
            scores = self._normalize_scores(scores)

        return scores
49 changes: 49 additions & 0 deletions
49
libs/community/tests/integration_tests/cross_encoders/test_mixedbreadai.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| """Test mixedbreadai cross encoders.""" | ||
|
|
||
| from langchain_community.cross_encoders import MixedbreadAICrossEncoder | ||
|
|
||
|
|
||
def _assert(encoder: MixedbreadAICrossEncoder) -> None:
    """Score one query against four texts and check the resulting order.

    The scores returned for the texts, in the order listed, are required to
    be strictly decreasing.
    """
    query = "I love you"
    texts = ["I love you", "I like you", "I don't like you", "I hate you"]
    pairs = [(query, text) for text in texts]
    output = encoder.score(pairs)

    # One score must come back per input pair.
    assert len(output) == len(texts)

    # Every score has to beat the one that follows it.
    for i, (current, following) in enumerate(zip(output, output[1:])):
        assert current > following, (
            f"Score at index {i} ({current}) should be greater than "
            f"score at index {i + 1} ({following})"
        )
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder() -> None:
    """Run the shared ranking check on an encoder built with default settings."""
    _assert(MixedbreadAICrossEncoder())
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder_with_designated_model_name() -> None:
    """Run the shared ranking check with an explicitly chosen model name."""
    # NOTE(review): this names a v1 checkpoint while the encoder constructs
    # its client with MxbaiRerankV2 — confirm that combination is supported.
    designated_model = "mixedbread-ai/mxbai-rerank-base-v1"
    _assert(MixedbreadAICrossEncoder(model_name=designated_model))
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder_without_normalization() -> None:
    """Run the shared ranking check with score normalization switched off."""
    raw_score_encoder = MixedbreadAICrossEncoder(normalize_scores=False)
    _assert(raw_score_encoder)
|
|
||
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder_multilingual() -> None:
    """Score a Spanish query against Spanish and English texts.

    Only the number of returned scores is checked, not their order.
    """
    query = "¿Cómo afecta la agricultura al clima?"
    texts = [
        "El cambio climático provoca sequías e inundaciones que afectan los cultivos.",
        "Climate change leads to droughts and floods, affecting crop yields.",
        "Agriculture is impacted by rising temperatures and unpredictable weather.",
    ]
    pairs = [(query, text) for text in texts]

    reranker = MixedbreadAICrossEncoder()
    output = reranker.score(pairs)

    assert len(output) == len(texts)
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.