-
-
Notifications
You must be signed in to change notification settings - Fork 373
[Feature Request] New NVIDIA reranker module #1199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
hypoxisaurea
wants to merge
3
commits into
Marker-Inc-Korea:main
Choose a base branch
from
hypoxisaurea:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| import os | ||
| from typing import List, Optional, Tuple | ||
|
|
||
| import aiohttp | ||
| import pandas as pd | ||
|
|
||
| from autorag.nodes.passagereranker.base import BasePassageReranker | ||
| from autorag.utils.util import get_event_loop, process_batch, result_to_dataframe | ||
|
|
||
|
|
||
class NvidiaReranker(BasePassageReranker):
    """Passage reranker backed by the NVIDIA retrieval reranking REST API."""

    def __init__(self, project_dir: str, *args, **kwargs):
        """
        Initialize Nvidia rerank node.

        :param project_dir: The project directory path.
        :param api_key: The API key for Nvidia rerank.
            You can set it in the environment variable NVIDIA_API_KEY.
            Or, you can directly set it on the config YAML file using this parameter.
            Default is env variable "NVIDIA_API_KEY".
        :param kwargs: Extra arguments that are not affected
        :raises KeyError: If no api_key kwarg is given and NVIDIA_API_KEY is unset.
        """
        super().__init__(project_dir)
        # The explicit kwarg wins; fall back to the environment variable.
        self.api_key = kwargs.pop("api_key", None)
        self.api_key = self.api_key or os.getenv("NVIDIA_API_KEY", None)
        if self.api_key is None:
            raise KeyError(
                "Please set the API key for Nvidia rerank in the environment variable NVIDIA_API_KEY "
                "or directly set it on the config YAML file."
            )
        self.invoke_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking"
        # One long-lived session reused for every request issued by this node.
        # NOTE(review): aiohttp has deprecated the explicit loop= argument for
        # ClientSession -- confirm the pinned aiohttp version still accepts it.
        self.session = aiohttp.ClientSession(loop=get_event_loop())
        self.session.headers.update(
            {"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"}
        )

    def __del__(self):
        # Best-effort cleanup of the aiohttp session at garbage collection.
        if hasattr(self, "session"):
            if not self.session.closed:
                loop = get_event_loop()
                if loop.is_running():
                    # The loop is busy: we can only schedule the close. The task
                    # may never run if the loop stops first -- acceptable for
                    # interpreter-shutdown cleanup.
                    loop.create_task(self.session.close())
                else:
                    loop.run_until_complete(self.session.close())
            del self.session
        super().__del__()

    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """Rerank the previous node's result DataFrame.

        Expects ``top_k`` in kwargs (required); ``batch``, ``model`` and
        ``truncate`` are optional -- see :meth:`_pure` for their meaning.
        """
        queries, contents, scores, ids = self.cast_to_run(previous_result)
        top_k = kwargs.pop("top_k")  # required -- raises KeyError if missing
        batch = kwargs.pop("batch", 64)
        model = kwargs.pop("model", "nvidia/rerank-qa-mistral-4b")
        truncate = kwargs.pop("truncate", None)
        return self._pure(queries, contents, scores, ids, top_k, batch, model, truncate)

    def _pure(
        self,
        queries: List[str],
        contents_list: List[List[str]],
        scores_list: List[List[float]],
        ids_list: List[List[str]],
        top_k: int,
        batch: int = 64,
        model: str = "nvidia/rerank-qa-mistral-4b",
        truncate: Optional[str] = None,
    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        """
        Rerank a list of contents with Nvidia rerank models.

        :param queries: The list of queries to use for reranking
        :param contents_list: The list of lists of contents to rerank
        :param scores_list: The list of lists of scores retrieved from the
            initial ranking. Not read by this method; kept for interface
            compatibility with the other rerankers.
        :param ids_list: The list of lists of ids retrieved from the initial ranking
        :param top_k: The number of passages to be retrieved
        :param batch: The number of queries to be processed in a batch
        :param model: The model name for Nvidia rerank.
            Default is "nvidia/rerank-qa-mistral-4b".
        :param truncate: Optional truncation strategy for the API request
        :return: Tuple of lists containing the reranked contents, ids, and scores
        """
        if not (len(queries) == len(contents_list) == len(ids_list)):
            raise AssertionError(
                "NvidiaReranker input length mismatch. "
                f"len(queries)={len(queries)}, len(contents_list)={len(contents_list)}, len(ids_list)={len(ids_list)}."
            )

        # One API call per query; process_batch throttles concurrency to `batch`.
        tasks = [
            nvidia_rerank_pure(
                self.session,
                self.invoke_url,
                model,
                query,
                document,
                ids,
                top_k,
                truncate=truncate,
            )
            for query, document, ids in zip(queries, contents_list, ids_list)
        ]
        loop = get_event_loop()
        results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
        if len(results) != len(queries):
            raise AssertionError(
                "NVIDIA rerank returned unexpected number of results. "
                f"expected={len(queries)}, got={len(results)}. "
                "Failing fast to prevent downstream index mapping errors."
            )

        # Transpose [(contents, ids, scores), ...] into three parallel lists.
        content_result, id_result, score_result = zip(*results)

        return list(content_result), list(id_result), list(score_result)
|
|
||
async def nvidia_rerank_pure(
    session: aiohttp.ClientSession,
    invoke_url: str,
    model: str,
    query: str,
    documents: List[str],
    ids: List[str],
    top_k: int,
    truncate: Optional[str] = None,
) -> Tuple[List[str], List[str], List[float]]:
    """
    Async function to call the NVIDIA Rerank API for a single query.

    :param session: The aiohttp session to use for reranking.
        Its headers are expected to already carry the Authorization token.
    :param invoke_url: The NVIDIA Rerank API endpoint
    :param model: The model name for Nvidia rerank
    :param query: The query to use for reranking
    :param documents: The list of contents to rerank
    :param ids: The list of ids corresponding to the documents
    :param top_k: The number of passages to be retrieved
    :param truncate: Optional truncation strategy for the API request
    :return: Tuple of lists containing the reranked contents, ids, and scores
    :raises ValueError: If the API responds with a non-200 status.
    :raises AssertionError: If the API returns a malformed rankings payload.
    """
    payload = {
        "model": model,
        "query": {"text": query},
        "passages": [{"text": doc} for doc in documents],
    }
    if truncate is not None:
        payload["truncate"] = truncate

    async with session.post(invoke_url, json=payload) as response:
        if response.status != 200:
            raise ValueError(
                f"NVIDIA API Error: {response.status} - {await response.text()}"
            )

        response_body = await response.json()

    rankings = response_body.get("rankings", [])
    expected_len = len(documents)
    if len(rankings) != expected_len:
        raise AssertionError(
            "NVIDIA rerank API returned unexpected rankings length. "
            f"expected={expected_len}, got={len(rankings)}. "
            "This can happen intermittently; failing fast to prevent index mapping errors."
        )

    # Fix: validate every index before it is used to address documents/ids.
    # A missing, non-int, or out-of-range index previously crashed with an
    # opaque KeyError/IndexError, and a *negative* index would silently map
    # the score to the wrong passage via Python's negative indexing.
    for item in rankings:
        index = item.get("index")
        if not isinstance(index, int) or not (0 <= index < expected_len):
            raise AssertionError(
                f"NVIDIA rerank API returned an invalid passage index {index!r}; "
                f"expected an integer in [0, {expected_len})."
            )

    def _score(item):
        # According to the NVIDIA documentation, the output can be either
        # probability scores or raw logits depending on the configuration.
        # So we check both fields to support various model settings.
        if item.get("logit") is not None:
            return float(item["logit"])
        if item.get("score") is not None:
            return float(item["score"])
        return 0.0

    # Highest score first, then keep only the top_k entries.
    rankings.sort(key=_score, reverse=True)

    top_rankings = rankings[:top_k]

    reranked_contents = [documents[item["index"]] for item in top_rankings]
    reranked_ids = [ids[item["index"]] for item in top_rankings]
    reranked_scores = [_score(item) for item in top_rankings]

    return reranked_contents, reranked_ids, reranked_scores
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| --- | ||
| myst: | ||
| html_meta: | ||
| title: AutoRAG - NVIDIA Reranker | ||
| description: Learn about NVIDIA reranker module in AutoRAG | ||
| keywords: AutoRAG,RAG,Advanced RAG,Reranker,NVIDIA | ||
| --- | ||
|
|
||
| # nvidia_reranker | ||
|
|
||
| The `NVIDIA reranker` module is a reranker from [NVIDIA](https://docs.api.nvidia.com/nim/reference/nvidia-rerank-qa-mistral-4b). | ||
|
|
||
| Currently, the NVIDIA reranker supports one model: 'nvidia/rerank-qa-mistral-4b'. | ||
|
|
||
| ## Before Usage | ||
|
|
||
| At first, you need to get the NVIDIA API key from [NVIDIA API](https://build.nvidia.com/nvidia/rerank-qa-mistral-4b). | ||
|
|
||
| Next, you can set your NVIDIA API key in the environment variable. | ||
|
|
||
| ```bash | ||
| export NVIDIA_API_KEY=your_nvidia_api_key | ||
| ``` | ||
|
|
||
| Or, you can set your NVIDIA API key in the config.yaml file directly. | ||
|
|
||
| ```yaml | ||
| - module_type: nvidia_reranker | ||
| api_key: your_nvidia_api_key | ||
| ``` | ||
|
|
||
| ## **Module Parameters** | ||
|
|
||
| - **batch** : The size of a batch. It sends the batch size of queries to NVIDIA API at once. If it is too large, it can | ||
cause rate-limit or request-size errors from the NVIDIA API. (default: 64) | 
| You can adjust this value based on your API rate limits and performance requirements. | ||
| - **model** : The type of model you want to use for reranking. Default is "nvidia/rerank-qa-mistral-4b". | ||
| Currently, only "nvidia/rerank-qa-mistral-4b" is tested and verified. | ||
| - **truncate** : Optional truncation strategy for the API request. If not specified, no truncation is applied. | ||
| You can set this parameter to control how the API handles long inputs. | ||
| - **api_key** : The NVIDIA API key. If not provided, it will use the NVIDIA_API_KEY environment variable. | ||
|
|
||
| ## **Example config.yaml** | ||
|
|
||
| ```yaml | ||
| - module_type: nvidia_reranker | ||
| api_key: your_nvidia_api_key | ||
| batch: 32 | ||
| model: nvidia/rerank-qa-mistral-4b | ||
| ``` |
132 changes: 132 additions & 0 deletions
132
tests/autorag/nodes/passagereranker/test_nvidia_reranker.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| import aiohttp | ||
| import pytest | ||
| from aioresponses import aioresponses | ||
|
|
||
| from autorag.nodes.passagereranker import NvidiaReranker | ||
| from autorag.nodes.passagereranker.nvidia import nvidia_rerank_pure | ||
| from autorag.utils.util import get_event_loop | ||
| from tests.autorag.nodes.passagereranker.test_passage_reranker_base import ( | ||
| queries_example, | ||
| contents_example, | ||
| scores_example, | ||
| ids_example, | ||
| base_reranker_test, | ||
| project_dir, | ||
| previous_result, | ||
| base_reranker_node_test, | ||
| ) | ||
|
|
||
# Endpoint the module under test posts to; mocked by aioresponses below.
NVIDIA_RERANK_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking"
# Canned rerank response shared by the tests. Each "index" refers to a position
# in the passage list submitted with the request, so passage 1 ranks highest,
# then 0, then 2.
MOCK_RESPONSE = {
    "rankings": [
        {"index": 1, "logit": 0.9},
        {"index": 0, "logit": 0.2},
        {"index": 2, "logit": 0.1},
    ]
}
|
|
||
@pytest.fixture
def nvidia_reranker_instance():
    """Yield a reranker built with a dummy key; close its HTTP session on teardown."""
    instance = NvidiaReranker(project_dir=project_dir, api_key="test")
    yield instance
    session = getattr(instance, "session", None)
    if session is not None and not session.closed:
        loop = get_event_loop()
        if loop.is_running():
            loop.create_task(session.close())
        else:
            loop.run_until_complete(session.close())
|
|
||
@pytest.mark.asyncio()
async def test_nvidia_rerank_pure():
    """Happy path: the mocked API response is mapped back onto the inputs."""
    with aioresponses() as mocked:
        mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE)

        async with aiohttp.ClientSession() as session:
            session.headers.update(
                {"Authorization": "Bearer mock_api_key", "Accept": "application/json"}
            )

            passages = ["doc0", "doc1", "doc2"]
            passage_ids = ["id0", "id1", "id2"]

            contents, ids, scores = await nvidia_rerank_pure(
                session,
                NVIDIA_RERANK_URL,
                "nvidia/rerank-qa-mistral-4b",
                queries_example[0],
                passages,
                passage_ids,
                top_k=2,
            )

            # top_k trims every result list to two entries.
            assert len(contents) == 2
            assert len(ids) == 2
            assert len(scores) == 2

            # Everything returned must originate from the submitted inputs.
            assert set(contents) <= set(passages)
            assert set(ids) <= set(passage_ids)

            # Scores come back in descending order.
            assert scores[0] >= scores[1]
|
|
||
|
|
||
@pytest.mark.asyncio()
async def test_nvidia_rerank_pure_raises_when_rankings_length_mismatch():
    """A response with fewer rankings than passages must fail fast."""
    with aioresponses() as mocked:
        # Two passages are submitted below, but the API answers with only one
        # ranking -- this must be rejected instead of silently mis-mapped.
        mocked.post(NVIDIA_RERANK_URL, payload={"rankings": [{"index": 0, "logit": 0.9}]})

        async with aiohttp.ClientSession() as session:
            session.headers.update(
                {"Authorization": "Bearer mock_api_key", "Accept": "application/json"}
            )
            with pytest.raises(AssertionError):
                await nvidia_rerank_pure(
                    session,
                    NVIDIA_RERANK_URL,
                    "nvidia/rerank-qa-mistral-4b",
                    queries_example[0],
                    ["doc0", "doc1"],
                    ["id0", "id1"],
                    top_k=2,
                )
|
|
||
def test_nvidia_reranker(nvidia_reranker_instance):
    """_pure() reranks every query against the mocked endpoint."""
    with aioresponses() as mocked:
        mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE, repeat=True)

        top_k = 3
        contents, ids, scores = nvidia_reranker_instance._pure(
            queries_example, contents_example, scores_example, ids_example, top_k
        )
        base_reranker_test(contents, ids, scores, top_k)
|
|
||
|
|
||
def test_nvidia_reranker_batch_one(nvidia_reranker_instance):
    """Same as test_nvidia_reranker, but forcing one query per API batch."""
    with aioresponses() as mocked:
        mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE, repeat=True)

        top_k = 3
        contents, ids, scores = nvidia_reranker_instance._pure(
            queries_example,
            contents_example,
            scores_example,
            ids_example,
            top_k,
            batch=1,
        )
        base_reranker_test(contents, ids, scores, top_k)
|
|
||
|
|
||
def test_nvidia_reranker_node():
    """End-to-end node run through run_evaluator with the endpoint mocked."""
    with aioresponses() as mocked:
        mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE, repeat=True)

        top_k = 1
        result_df = NvidiaReranker.run_evaluator(
            project_dir=project_dir,
            previous_result=previous_result,
            top_k=top_k,
            api_key="test",
        )
        base_reranker_node_test(result_df, top_k)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should check that the input documents lengths and
rankingswhich is returned from the NVIDIA API are same length.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the detailed review! I've updated the code based on your comments. Please check it again.