generated from langchain-ai/integration-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 269
feat:(integration): add MixedbreadAICrossEncoder for reranking documents #274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
adityaiiitmk
wants to merge
14
commits into
langchain-ai:main
Choose a base branch
from
adityaiiitmk:aditya/mixedbreadai
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 8 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
4fa5765
feat:(integration): add MixedbreadAICrossEncoder for reranking docum…
adityaiiitmk 2a8a0eb
test(integration): add tests for MixedbreadAICrossEncoder functionality
adityaiiitmk 89a69ac
Update libs/community/langchain_community/cross_encoders/mixedbreadai.py
adityaiiitmk 91bf9c4
Update libs/community/langchain_community/cross_encoders/mixedbreadai.py
adityaiiitmk a72c34b
Update libs/community/langchain_community/cross_encoders/mixedbreadai.py
adityaiiitmk c447cc7
Update libs/community/tests/integration_tests/cross_encoders/test_mix…
adityaiiitmk 9c49558
fix:Update test_mixedbreadai.py
adityaiiitmk 18f0fd1
fix:Update test_mixedbreadai.py
adityaiiitmk b9d6e07
fix : Update mixedbreadai.py based on code review
adityaiiitmk 40cbdba
fix: Correct docstring and import for MixedbreadAICrossEncoder
adityaiiitmk 7bac77a
Merge branch 'main' into aditya/mixedbreadai
adityaiiitmk 7b71897
Merge branch 'main' into aditya/mixedbreadai
adityaiiitmk 114af4d
Merge branch 'main' into aditya/mixedbreadai
mdrxy 029450c
Merge branch 'main' into aditya/mixedbreadai
mdrxy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
libs/community/langchain_community/cross_encoders/mixedbreadai.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| from typing import Any, Dict, List, Tuple | ||
|
|
||
| from pydantic import BaseModel, ConfigDict, Field | ||
|
|
||
| from langchain_community.cross_encoders.base import BaseCrossEncoder | ||
|
|
||
| DEFAULT_MODEL_NAME = "mixedbread-ai/mxbai-rerank-base-v2" | ||
|
|
||
|
|
||
class MixedbreadAICrossEncoder(BaseModel, BaseCrossEncoder):
    """Mixedbread cross encoder models.

    Scores ``(query, document)`` pairs with a reranker loaded through
    ``mxbai_rerank.MxbaiRerankV2``.

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders import MixedbreadAICrossEncoder

            model_name = "mixedbread-ai/mxbai-rerank-base-v2"
            model_kwargs = {'top_k': 10}
            mb = MixedbreadAICrossEncoder(
                model_name=model_name,
                model_kwargs=model_kwargs
            )
    """

    client: Any = None  #: :meta private:
    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments forwarded to the ``MxbaiRerankV2`` constructor."""
    normalize_scores: bool = Field(default=True)
    """Whether to min-max normalize the returned scores to the [0, 1] range."""

    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    def __init__(self, **kwargs: Any):
        """Initialize the Mixedbread reranker.

        Args:
            **kwargs: Field values for this model (``model_name``,
                ``model_kwargs``, ``normalize_scores``).

        Raises:
            ImportError: If the ``mxbai_rerank`` package is not installed.
        """
        super().__init__(**kwargs)
        # Imported lazily so the optional dependency is only required when
        # this encoder is actually instantiated.
        try:
            from mxbai_rerank import MxbaiRerankV2
        except ImportError as exc:
            raise ImportError(
                "Could not import mxbai_rerank python package. "
                "Please install it with `pip install mxbai-rerank`."
            ) from exc

        self.client = MxbaiRerankV2(self.model_name, **self.model_kwargs)

    def _normalize_scores(self, scores: List[float]) -> List[float]:
        """Min-max normalize scores to the [0, 1] range.

        Args:
            scores: Raw scores to rescale.

        Returns:
            The rescaled scores in the same order. An empty list is returned
            unchanged, and a list whose values are all equal maps to all
            ``0.0`` (there is no spread to scale by).
        """
        if not scores:
            return scores

        min_score = min(scores)
        max_score = max(scores)

        # Guard against division by zero when every score is identical.
        if max_score == min_score:
            return [0.0] * len(scores)

        spread = max_score - min_score
        return [(score - min_score) / spread for score in scores]

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Compute similarity scores using a Mixedbread reranker model.

        Pairs are grouped by query so that every document is scored against
        its own query; the client is called once per distinct query. When all
        pairs share one query this is a single call.

        Args:
            text_pairs: The list of text pairs to score the similarity.
                Each tuple should be (query, document).

        Returns:
            List of similarity/relevance scores between query-document pairs,
            one float score for each input pair, in input order. If
            ``normalize_scores`` is set, the scores are min-max normalized
            over the whole returned list.
        """
        if not text_pairs:
            return []

        # Map each distinct query to the positions of its pairs in the input,
        # so results can be written back in the caller's original order.
        positions_by_query: Dict[str, List[int]] = {}
        for position, pair in enumerate(text_pairs):
            positions_by_query.setdefault(pair[0], []).append(position)

        # Any document the client does not return a result for keeps 0.0.
        scores: List[float] = [0.0] * len(text_pairs)
        for query, positions in positions_by_query.items():
            documents = [text_pairs[position][1] for position in positions]

            # top_k=len(documents) asks for a score for every document
            # rather than only the best few.
            results = self.client.rank(
                query,
                documents,
                return_documents=False,
                top_k=len(documents),
            )

            # res.index refers to the position within ``documents``;
            # translate it back to the position within ``text_pairs``.
            for res in results:
                scores[positions[res.index]] = res.score

        # Normalize scores if requested
        if self.normalize_scores:
            scores = self._normalize_scores(scores)

        return scores
49 changes: 49 additions & 0 deletions
49
libs/community/tests/integration_tests/cross_encoders/test_mixedbreadai.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| """Test mixedbreadai cross encoders.""" | ||
|
|
||
| from langchain_community.cross_encoders import MixedbreadAICrossEncoder | ||
|
|
||
|
|
||
def _assert(encoder: MixedbreadAICrossEncoder) -> None:
    """Score a fixed query against four texts and check the ranking.

    The texts are listed from most to least similar to the query, so the
    returned scores are expected to be strictly decreasing.
    """
    query = "I love you"
    texts = ["I love you", "I like you", "I don't like you", "I hate you"]
    pairs = [(query, text) for text in texts]
    output = encoder.score(pairs)

    # One score must come back for every input pair.
    assert len(output) == len(texts)

    # Each score must be strictly greater than the one that follows it.
    for i in range(len(texts) - 1):
        assert output[i] > output[i + 1], (
            f"Score at index {i} ({output[i]}) should be greater than "
            f"score at index {i+1} ({output[i+1]})"
        )
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder() -> None:
    """Run the shared ranking checks against the default model."""
    _assert(MixedbreadAICrossEncoder())
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder_with_designated_model_name() -> None:
    """Run the shared ranking checks against an explicitly named model."""
    # NOTE(review): this names a v1 checkpoint, while the encoder builds its
    # client with MxbaiRerankV2 -- confirm that combination is supported.
    designated_model = "mixedbread-ai/mxbai-rerank-base-v1"
    _assert(MixedbreadAICrossEncoder(model_name=designated_model))
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder_without_normalization() -> None:
    """Run the shared ranking checks with score normalization turned off."""
    _assert(MixedbreadAICrossEncoder(normalize_scores=False))
|
|
||
|
|
||
|
|
||
def test_mixedbreadai_cross_encoder_multilingual() -> None:
    """Check that a Spanish query over mixed-language texts yields one score each."""
    query = "¿Cómo afecta la agricultura al clima?"
    texts = [
        "El cambio climático provoca sequías e inundaciones que afectan los cultivos.",
        "Climate change leads to droughts and floods, affecting crop yields.",
        "Agriculture is impacted by rising temperatures and unpredictable weather.",
    ]
    pairs = [(query, text) for text in texts]

    encoder = MixedbreadAICrossEncoder()
    output = encoder.score(pairs)

    # Only the number of scores is checked here, not their ordering.
    assert len(output) == len(texts)
|
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.