Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ jobs:
CF_GATEWAY_ENDPOINT: ${{ secrets.CF_GATEWAY_ENDPOINT }}
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
NOMIC_API_KEY: ${{ secrets.NOMIC_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pip install chromadbx
- [SpaCy](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#spacy) embeddings
- [Together](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#together) embeddings.
- [Nomic](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#nomic) embeddings.
- [Reranking](https://github.com/amikos-tech/chromadbx/blob/main/docs/reranking.md) - rerank documents and query results using Cohere, OpenAI, or custom reranking functions.
- [Cohere](https://github.com/amikos-tech/chromadbx/blob/main/docs/reranking.md#cohere) - rerank documents and query results using Cohere.

## Usage

Expand Down
4 changes: 2 additions & 2 deletions chromadbx/reranking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class RerankedQueryResult(TypedDict):
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[Distance]]]
included: Include
ranked_distances: Dict[RerankerID, List[Distance]]
ranked_distances: Dict[RerankerID, List[Distances]]


class RerankedDocuments(TypedDict):
documents: List[Documents]
ranked_distances: Dict[RerankerID, Distances]


RankedResults = Union[List[Documents], List[RerankedQueryResult]]
RankedResults = Union[RerankedDocuments, RerankedQueryResult]

D = TypeVar("D", bound=Rerankable, contravariant=True)
T = TypeVar("T", bound=RankedResults, covariant=True)
Expand Down
146 changes: 146 additions & 0 deletions chromadbx/reranking/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os
from typing import Any, Dict, Optional, List

from chromadbx.reranking import (
Queries,
RankedResults,
Rerankable,
RerankedDocuments,
RerankedQueryResult,
RerankerID,
RerankingFunction,
)
from chromadbx.reranking.utils import get_query_documents_tuples


class CohereReranker(RerankingFunction[Rerankable, RankedResults]):
def __init__(
self,
api_key: str,
model_name: Optional[str] = "rerank-v3.5",
*,
raw_scores: bool = False,
top_n: Optional[int] = None,
max_tokens_per_document: Optional[int] = 4096,
timeout: Optional[int] = 60,
max_retries: Optional[int] = 3,
additional_headers: Optional[Dict[str, Any]] = None,
):
"""
Initialize the CohereReranker.

Args:
api_key: The Cohere API key.
model_name: The Cohere model to use for reranking. Defaults to `rerank-v3.5`.
raw_scores: Whether to return the raw scores from the Cohere API. Defaults to `False`.
top_n: The number of results to return. Defaults to `None`.
max_tokens_per_document: The maximum number of tokens per document. Defaults to `4096`.
timeout: The timeout for the Cohere API request. Defaults to `60`.
max_retries: The maximum number of retries for the Cohere API request. Defaults to `3`.
additional_headers: Additional headers to include in the Cohere API request. Defaults to `None`.
"""
try:
import cohere
from cohere.core.request_options import RequestOptions
except ImportError:
raise ImportError(
"cohere is not installed. Please install it with `pip install cohere`"
)
if not api_key and not os.getenv("COHERE_API_KEY"):
raise ValueError(
"API key is required. Please set the COHERE_API_KEY environment variable or pass it directly."
)
if not model_name:
raise ValueError(
"Model name is required. Please set the model_name parameter or use the default value."
)
self._client = cohere.ClientV2(api_key)
self._model_name = model_name
self._top_n = top_n
self._raw_scores = raw_scores
self._max_tokens_per_document = max_tokens_per_document
self._request_options = RequestOptions(
timeout_in_seconds=timeout,
max_retries=max_retries,
additional_headers=additional_headers,
)

def id(self) -> RerankerID:
return RerankerID("cohere")

def _combine_reranked_results(
self, results_list: List["cohere.v2.types.V2RerankResponse"], rerankables: Rerankable # type: ignore # noqa: F821
) -> RankedResults:
all_ordered_scores = []

for results in results_list:
if self._raw_scores:
ordered_scores = [
r.relevance_score
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
]
else: # by default we calculate the distance to make results comparable with Chroma distance
ordered_scores = [
1 - r.relevance_score
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
]
all_ordered_scores.append(ordered_scores)

if isinstance(rerankables, dict):
combined_ordered_scores = [
score for sublist in all_ordered_scores for score in sublist
]
if len(rerankables["ids"]) != len(combined_ordered_scores):
combined_ordered_scores = combined_ordered_scores + [None] * (
len(rerankables["ids"]) - len(combined_ordered_scores)
)
return RerankedQueryResult(
ids=rerankables["ids"],
embeddings=rerankables["embeddings"]
if "embeddings" in rerankables
else None,
documents=rerankables["documents"]
if "documents" in rerankables
else None,
uris=rerankables["uris"] if "uris" in rerankables else None,
data=rerankables["data"] if "data" in rerankables else None,
metadatas=rerankables["metadatas"]
if "metadatas" in rerankables
else None,
distances=rerankables["distances"]
if "distances" in rerankables
else None,
included=rerankables["included"] if "included" in rerankables else None,
ranked_distances={self.id(): combined_ordered_scores},
)
elif isinstance(rerankables, list):
if len(results_list) > 1:
raise ValueError("Cannot rerank documents with multiple results")
combined_ordered_scores = [
score for sublist in all_ordered_scores for score in sublist
]
if len(rerankables) != len(combined_ordered_scores):
combined_ordered_scores = combined_ordered_scores + [None] * (
len(rerankables) - len(combined_ordered_scores)
)
return RerankedDocuments(
documents=rerankables,
ranked_distances={self.id(): combined_ordered_scores},
)
else:
raise ValueError("Invalid rerankables type")

def __call__(self, queries: Queries, rerankables: Rerankable) -> RankedResults:
query_documents_tuples = get_query_documents_tuples(queries, rerankables)
results = []
for query, documents in query_documents_tuples:
response = self._client.rerank(
model=self._model_name,
query=query,
documents=documents,
top_n=self._top_n or len(documents),
max_tokens_per_doc=self._max_tokens_per_document,
request_options=self._request_options,
)
results.append(response)
return self._combine_reranked_results(results, rerankables)
79 changes: 79 additions & 0 deletions docs/reranking.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Reranking

Reranking is a process of reordering a list of items based on their relevance to a query. This project supports reranking of documents and query results.

```python
from chromadbx.reranking.some_reranker import SomeReranker
import chromadb
some_reranker = SomeReranker()

client = chromadb.Client()

collection = client.get_collection("documents")

results = collection.query(
query_texts=["What is the capital of the United States?"],
n_results=10,
)

reranked_results = some_reranker(results)

print("Documents:", reranked_results["documents"][0])
print("Distances:", reranked_results["distances"][0])
print("Reranked distances:", reranked_results["ranked_distances"][some_reranker.id()][0])
```

> [!NOTE]
> It is our intent that all officially supported reranking functions shall return distances instead of scores to be consistent with the core Chroma project. However, this is not a hard requirement and you should check the documentation for each reranking function you plan to use.

The following reranking functions are supported:

| Reranking Function | Official Docs |
| ------------------ | ------------- |
| [Cohere](#cohere) | [docs](https://docs.cohere.com/docs/rerank-2) |

## Cohere

Cohere reranking function offers a convinient wrapper around the Cohere API to rerank documents and query results. For more information on Cohere reranking, visit the official [docs](https://docs.cohere.com/docs/rerank-2) or [API docs](https://docs.cohere.com/reference/rerank).

You need to install the `cohere` package to use this reranking function.


```bash
pip install cohere # or poetry add cohere
```

Before using the reranking function, you need to obtain [Cohere API](https://dashboard.cohere.com/api-keys) key and set the `COHERE_API_KEY` environment variable.

> [!TIP]
> By default, the reranking function will return distances. If you need to get the raw scores, set the `raw_scores` parameter to `True`.

```python
import os
import chromadb
from chromadbx.reranking import CohereReranker

cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY"))

client = chromadb.Client()

collection = client.get_collection("documents")

results = collection.query(
query_texts=["What is the capital of the United States?"],
n_results=10,
)

reranked_results = cohere(results)
```

Available options:

- `api_key`: The Cohere API key.
- `model_name`: The Cohere model to use for reranking. Defaults to `rerank-v3.5`.
- `raw_scores`: Whether to return the raw scores from the Cohere API. Defaults to `False`.
- `top_n`: The number of results to return. Defaults to `None`.
- `max_tokens_per_document`: The maximum number of tokens per document. Defaults to `4096`.
- `timeout`: The timeout for the Cohere API request. Defaults to `60`.
- `max_retries`: The maximum number of retries for the Cohere API request. Defaults to `3`.
- `additional_headers`: Additional headers to include in the Cohere API request. Defaults to `None`.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ llama-embedder = "^0.0.7"
mistralai = "^1.1.0"
spacy = "^3.8.4"
together = "^1.3.11"
cohere = "^5.13.8"

[tool.poetry.extras]
ids = ["ulid-py", "nanoid"]
embeddings = ["llama-embedder", "onnxruntime", "huggingface_hub", "mistralai", "spacy", "together", "vertexai"]
reranking = ["cohere"]
core = ["chromadb"]

[build-system]
Expand Down
2 changes: 1 addition & 1 deletion test/embeddings/test_nomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from chromadbx.embeddings.nomic import NomicEmbeddingFunction

httpx = pytest.importorskip("httpx", reason="nomic not installed")
httpx = pytest.importorskip("httpx", reason="httpx not installed")


@pytest.mark.skipif(
Expand Down
70 changes: 70 additions & 0 deletions test/reranking/test_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
from typing import cast

from chromadb import QueryResult
import pytest
from chromadbx.reranking import RerankedDocuments, RerankedQueryResult
from chromadbx.reranking.cohere import CohereReranker


from unittest.mock import MagicMock

_cohere = pytest.importorskip("cohere", reason="cohere not installed")


def test_cohere_mock_rerank_documents() -> None:
mock_client = MagicMock()
mock_client.rerank.return_value = MagicMock(results=[])

cohere = CohereReranker(api_key="test")
cohere._client = mock_client

queries = "What is the capital of the United States?"
rerankables = ["Washington, D.C.", "New York", "Los Angeles"]

cohere(queries, rerankables)
mock_client.rerank.assert_called_once_with(
model="rerank-v3.5",
query=queries,
documents=rerankables,
top_n=len(rerankables),
max_tokens_per_doc=4096,
request_options=cohere._request_options,
)


@pytest.mark.skipif(
os.getenv("COHERE_API_KEY") is None,
reason="COHERE_API_KEY environment variable is not set",
)
def test_cohere_rerank_documents() -> None:
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY", ""))
queries = "What is the capital of the United States?"
rerankables = ["Washington, D.C.", "New York", "Los Angeles"]
result = cast(RerankedDocuments, cohere(queries, rerankables))
assert "ranked_distances" in result
assert len(result["ranked_distances"][cohere.id()]) == len(rerankables)
assert result["ranked_distances"][cohere.id()].index(
min(result["ranked_distances"][cohere.id()])
) == rerankables.index("Washington, D.C.")


@pytest.mark.skipif(
os.getenv("COHERE_API_KEY") is None,
reason="COHERE_API_KEY environment variable is not set",
)
def test_cohere_rerank_documents_with_query_result() -> None:
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY", ""))
queries = ["What is the capital of the United States?"]
rerankables = QueryResult(
documents=[["Washington, D.C.", "New York", "Los Angeles"]],
metadatas=[[{"source": "test"}, {"source": "test"}, {"source": "test"}]],
embeddings=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
ids=[["id1", "id2", "id3"]],
)
result = cast(RerankedQueryResult, cohere(queries, rerankables))
assert "ranked_distances" in result
assert len(result["ranked_distances"][cohere.id()]) == len(rerankables["ids"][0])
assert result["ranked_distances"][cohere.id()].index(
min(result["ranked_distances"][cohere.id()])
) == rerankables["ids"][0].index("id1")
Loading