Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autorag/nodes/passagereranker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .voyageai import VoyageAIReranker
from .mixedbreadai import MixedbreadAIReranker
from .flashrank import FlashRankReranker
from .nvidia import NvidiaReranker
180 changes: 180 additions & 0 deletions autorag/nodes/passagereranker/nvidia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import os
from typing import List, Optional, Tuple

import aiohttp
import pandas as pd

from autorag.nodes.passagereranker.base import BasePassageReranker
from autorag.utils.util import get_event_loop, process_batch, result_to_dataframe


class NvidiaReranker(BasePassageReranker):
	# Passage reranker backed by the NVIDIA retrieval reranking REST API.
	def __init__(self, project_dir: str, *args, **kwargs):
		"""
		Initialize Nvidia rerank node.

		:param project_dir: The project directory path.
		:param api_key: The API key for Nvidia rerank.
			You can set it in the environment variable NVIDIA_API_KEY.
			Or, you can directly set it on the config YAML file using this parameter.
			Default is env variable "NVIDIA_API_KEY".
		:param kwargs: Extra arguments that are not affected
		:raises KeyError: If no API key is given via kwargs or the environment.
		"""
		super().__init__(project_dir)
		# Prefer an explicit api_key kwarg; fall back to the environment variable.
		self.api_key = kwargs.pop("api_key", None)
		self.api_key = self.api_key or os.getenv("NVIDIA_API_KEY", None)
		if self.api_key is None:
			raise KeyError(
				"Please set the API key for Nvidia rerank in the environment variable NVIDIA_API_KEY "
				"or directly set it on the config YAML file."
			)
		self.invoke_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking"
		# NOTE(review): the `loop=` argument to ClientSession is deprecated in
		# newer aiohttp releases — confirm the pinned aiohttp version accepts it.
		self.session = aiohttp.ClientSession(loop=get_event_loop())
		self.session.headers.update(
			{"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"}
		)

	def __del__(self):
		# Best-effort cleanup of the aiohttp session at garbage collection time.
		if hasattr(self, "session"):
			if not self.session.closed:
				loop = get_event_loop()
				if loop.is_running():
					# Cannot block inside a running loop; schedule the close.
					loop.create_task(self.session.close())
				else:
					loop.run_until_complete(self.session.close())
			del self.session
		super().__del__()

	@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
	def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
		"""
		Rerank the previous retrieval result with the NVIDIA rerank API.

		:param previous_result: DataFrame holding queries and retrieved passages.
		:param kwargs: Must contain ``top_k``; may contain ``batch``, ``model``
			and ``truncate`` (see ``_pure`` for their meanings).
		:return: DataFrame of reranked contents, ids, and scores
			(built by the ``result_to_dataframe`` decorator).
		"""
		queries, contents, scores, ids = self.cast_to_run(previous_result)
		# top_k is required; the remaining parameters have defaults.
		top_k = kwargs.pop("top_k")
		batch = kwargs.pop("batch", 64)
		model = kwargs.pop("model", "nvidia/rerank-qa-mistral-4b")
		truncate = kwargs.pop("truncate", None)
		return self._pure(queries, contents, scores, ids, top_k, batch, model, truncate)

	def _pure(
		self,
		queries: List[str],
		contents_list: List[List[str]],
		scores_list: List[List[float]],
		ids_list: List[List[str]],
		top_k: int,
		batch: int = 64,
		model: str = "nvidia/rerank-qa-mistral-4b",
		truncate: Optional[str] = None,
	) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
		"""
		Rerank a list of contents with Nvidia rerank models.

		:param queries: The list of queries to use for reranking
		:param contents_list: The list of lists of contents to rerank
		:param scores_list: The list of lists of scores retrieved from the initial ranking.
			Note: accepted but never read here — presumably kept for interface
			uniformity with other rerankers; confirm against the base class.
		:param ids_list: The list of lists of ids retrieved from the initial ranking
		:param top_k: The number of passages to be retrieved
		:param batch: The number of queries to be processed in a batch
		:param model: The model name for Nvidia rerank.
			Default is "nvidia/rerank-qa-mistral-4b".
		:param truncate: Optional truncation strategy for the API request
		:return: Tuple of lists containing the reranked contents, ids, and scores
		:raises AssertionError: On input length mismatch or when the API
			returns a different number of results than queries sent.
		"""
		if not (len(queries) == len(contents_list) == len(ids_list)):
			raise AssertionError(
				"NvidiaReranker input length mismatch. "
				f"len(queries)={len(queries)}, len(contents_list)={len(contents_list)}, len(ids_list)={len(ids_list)}."
			)

		# One coroutine per query; process_batch limits concurrency to `batch`.
		tasks = [
			nvidia_rerank_pure(
				self.session,
				self.invoke_url,
				model,
				query,
				document,
				ids,
				top_k,
				truncate=truncate,
			)
			for query, document, ids in zip(queries, contents_list, ids_list)
		]
		loop = get_event_loop()
		results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
		if len(results) != len(queries):
			raise AssertionError(
				"NVIDIA rerank returned unexpected number of results. "
				f"expected={len(queries)}, got={len(results)}. "
				"Failing fast to prevent downstream index mapping errors."
			)

		# Each result is a (contents, ids, scores) triple; transpose into columns.
		content_result, id_result, score_result = zip(*results)

		return list(content_result), list(id_result), list(score_result)

async def nvidia_rerank_pure(
	session: aiohttp.ClientSession,
	invoke_url: str,
	model: str,
	query: str,
	documents: List[str],
	ids: List[str],
	top_k: int,
	truncate: Optional[str] = None,
) -> Tuple[List[str], List[str], List[float]]:
	"""
	Async function to call Nvidia Rerank API.

	:param session: The aiohttp session to use for reranking
	:param invoke_url: The Nvidia Rerank API endpoint
	:param model: The model name for Nvidia rerank
	:param query: The query to use for reranking
	:param documents: The list of contents to rerank
	:param ids: The list of ids corresponding to the documents
	:param top_k: The number of passages to be retrieved
	:param truncate: Optional truncation strategy for the API request
	:return: Tuple of lists containing the reranked contents, ids, and scores
	:raises ValueError: If the API responds with a non-200 status.
	:raises AssertionError: If the API returns a rankings list whose length
		differs from the number of input documents.
	"""
	payload = {
		"model": model,
		"query": {"text": query},
		"passages": [{"text": doc} for doc in documents],
	}
	if truncate is not None:
		payload["truncate"] = truncate

	async with session.post(invoke_url, json=payload) as response:
		if response.status != 200:
			raise ValueError(
				f"NVIDIA API Error: {response.status} - {await response.text()}"
			)
		response_body = await response.json()

	rankings = response_body.get("rankings", [])
	# The API must rank every passage we sent; a short/long rankings list
	# would make the index-based mapping below silently wrong, so fail fast.
	expected_len = len(documents)
	if len(rankings) != expected_len:
		raise AssertionError(
			"NVIDIA rerank API returned unexpected rankings length. "
			f"expected={expected_len}, got={len(rankings)}. "
			"This can happen intermittently; failing fast to prevent index mapping errors."
		)

	def _score(item):
		# According to the NVIDIA documentation, the output can be either
		# probability scores or raw logits depending on the configuration.
		# So we check both fields to support various model settings.
		if item.get("logit") is not None:
			return float(item["logit"])
		if item.get("score") is not None:
			return float(item["score"])
		return 0.0

	# Sort best-first and keep only the top_k passages.
	rankings.sort(key=_score, reverse=True)
	top_rankings = rankings[:top_k]

	# Map the API's passage indices back onto our inputs.
	reranked_contents = [documents[item["index"]] for item in top_rankings]
	reranked_ids = [ids[item["index"]] for item in top_rankings]
	reranked_scores = [_score(item) for item in top_rankings]

	return reranked_contents, reranked_ids, reranked_scores
50 changes: 50 additions & 0 deletions docs/source/nodes/passage_reranker/nvidia_reranker.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
---
myst:
html_meta:
title: AutoRAG - NVIDIA Reranker
description: Learn about NVIDIA reranker module in AutoRAG
keywords: AutoRAG,RAG,Advanced RAG,Reranker,NVIDIA
---

# nvidia_reranker

The `NVIDIA reranker` module is a reranker from [NVIDIA](https://docs.api.nvidia.com/nim/reference/nvidia-rerank-qa-mistral-4b).

Currently, a single model is supported: 'nvidia/rerank-qa-mistral-4b'.

## Before Usage

At first, you need to get the NVIDIA API key from [NVIDIA API](https://build.nvidia.com/nvidia/rerank-qa-mistral-4b).

Next, you can set your NVIDIA API key in the environment variable.

```bash
export NVIDIA_API_KEY=your_nvidia_api_key
```

Or, you can set your NVIDIA API key in the config.yaml file directly.

```yaml
- module_type: nvidia_reranker
api_key: your_nvidia_api_key
```

## **Module Parameters**

- **batch** : The batch size. It determines how many queries are sent to the NVIDIA API concurrently. If it is
  too large, requests may fail due to API rate limits or request-size limits. (default: 64)
  You can adjust this value based on your API rate limits and performance requirements.
- **model** : The type of model you want to use for reranking. Default is "nvidia/rerank-qa-mistral-4b".
Currently, only "nvidia/rerank-qa-mistral-4b" is tested and verified.
- **truncate** : Optional truncation strategy for the API request. If not specified, no truncation is applied.
You can set this parameter to control how the API handles long inputs.
- **api_key** : The NVIDIA API key. If not provided, it will use the NVIDIA_API_KEY environment variable.

## **Example config.yaml**

```yaml
- module_type: nvidia_reranker
api_key: your_nvidia_api_key
batch: 32
model: nvidia/rerank-qa-mistral-4b
```
132 changes: 132 additions & 0 deletions tests/autorag/nodes/passagereranker/test_nvidia_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import aiohttp
import pytest
from aioresponses import aioresponses

from autorag.nodes.passagereranker import NvidiaReranker
from autorag.nodes.passagereranker.nvidia import nvidia_rerank_pure
from autorag.utils.util import get_event_loop
from tests.autorag.nodes.passagereranker.test_passage_reranker_base import (
queries_example,
contents_example,
scores_example,
ids_example,
base_reranker_test,
project_dir,
previous_result,
base_reranker_node_test,
)

# Real NVIDIA rerank endpoint; intercepted by aioresponses in every test below.
NVIDIA_RERANK_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking"
# Canned API response for three passages, ranked by logit (index 1 is best).
MOCK_RESPONSE = {
	"rankings": [
		{"index": 1, "logit": 0.9},
		{"index": 0, "logit": 0.2},
		{"index": 2, "logit": 0.1},
	]
}

@pytest.fixture
def nvidia_reranker_instance():
	# Build a reranker with a dummy API key; HTTP calls are mocked in tests.
	reranker = NvidiaReranker(project_dir=project_dir, api_key="test")
	yield reranker
	# Teardown: close the aiohttp session if it is still open.
	session = getattr(reranker, "session", None)
	if session is not None and not session.closed:
		loop = get_event_loop()
		if loop.is_running():
			loop.create_task(session.close())
		else:
			loop.run_until_complete(session.close())

@pytest.mark.asyncio()
async def test_nvidia_rerank_pure():
	with aioresponses() as mocked:
		mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE)

		async with aiohttp.ClientSession() as session:
			session.headers.update(
				{"Authorization": "Bearer mock_api_key", "Accept": "application/json"}
			)

			docs = ["doc0", "doc1", "doc2"]
			doc_ids = ["id0", "id1", "id2"]

			contents, result_ids, scores = await nvidia_rerank_pure(
				session,
				NVIDIA_RERANK_URL,
				"nvidia/rerank-qa-mistral-4b",
				queries_example[0],
				docs,
				doc_ids,
				top_k=2,
			)

			# Exactly top_k results, drawn from the inputs, best score first.
			assert len(contents) == len(result_ids) == len(scores) == 2
			assert all(content in docs for content in contents)
			assert all(doc_id in doc_ids for doc_id in result_ids)
			assert scores[0] >= scores[1]


@pytest.mark.asyncio()
async def test_nvidia_rerank_pure_raises_when_rankings_length_mismatch():
	# The API "ranks" only one of the two documents sent -> must fail fast.
	short_response = {"rankings": [{"index": 0, "logit": 0.9}]}
	with aioresponses() as mocked:
		mocked.post(NVIDIA_RERANK_URL, payload=short_response)

		async with aiohttp.ClientSession() as session:
			session.headers.update(
				{"Authorization": "Bearer mock_api_key", "Accept": "application/json"}
			)
			with pytest.raises(AssertionError):
				await nvidia_rerank_pure(
					session,
					NVIDIA_RERANK_URL,
					"nvidia/rerank-qa-mistral-4b",
					queries_example[0],
					["doc0", "doc1"],
					["id0", "id1"],
					top_k=2,
				)

def test_nvidia_reranker(nvidia_reranker_instance):
	# Exercise _pure end-to-end against the mocked endpoint.
	with aioresponses() as mocked:
		mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE, repeat=True)

		top_k = 3
		contents, result_ids, scores = nvidia_reranker_instance._pure(
			queries_example, contents_example, scores_example, ids_example, top_k
		)
		base_reranker_test(contents, result_ids, scores, top_k)


def test_nvidia_reranker_batch_one(nvidia_reranker_instance):
	# Same as test_nvidia_reranker, but forcing one query per batch.
	with aioresponses() as mocked:
		mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE, repeat=True)

		top_k = 3
		contents, result_ids, scores = nvidia_reranker_instance._pure(
			queries_example,
			contents_example,
			scores_example,
			ids_example,
			top_k,
			batch=1,
		)
		base_reranker_test(contents, result_ids, scores, top_k)


def test_nvidia_reranker_node():
	# Full node run through the evaluator entry point.
	with aioresponses() as mocked:
		mocked.post(NVIDIA_RERANK_URL, payload=MOCK_RESPONSE, repeat=True)

		top_k = 1
		result_df = NvidiaReranker.run_evaluator(
			project_dir=project_dir,
			previous_result=previous_result,
			top_k=top_k,
			api_key="test",
		)
		base_reranker_node_test(result_df, top_k)