Skip to content

Commit c6e9bff

Browse files
committed
feat: Tests and docs
1 parent 1639a37 commit c6e9bff

File tree

7 files changed

+182
-66
lines changed

7 files changed

+182
-66
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ jobs:
8787
CF_GATEWAY_ENDPOINT: ${{ secrets.CF_GATEWAY_ENDPOINT }}
8888
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
8989
NOMIC_API_KEY: ${{ secrets.NOMIC_API_KEY }}
90+
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pip install chromadbx
1919
- [SpaCy](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#spacy) embeddings
2020
- [Together](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#together) embeddings.
2121
- [Nomic](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#nomic) embeddings.
22+
- [Reranking](https://github.com/amikos-tech/chromadbx/blob/main/docs/reranking.md) - rerank documents and query results using Cohere, OpenAI, or custom reranking functions.
23+
- [Cohere](https://github.com/amikos-tech/chromadbx/blob/main/docs/reranking.md#cohere) - rerank documents and query results using Cohere.
2224

2325
## Usage
2426

chromadbx/reranking/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ class RerankedQueryResult(TypedDict):
4848
metadatas: Optional[List[List[Metadata]]]
4949
distances: Optional[List[List[Distance]]]
5050
included: Include
51-
ranked_distances: Dict[RerankerID, List[Distance]]
51+
ranked_distances: Dict[RerankerID, List[Distances]]
5252

5353

5454
class RerankedDocuments(TypedDict):
5555
documents: List[Documents]
5656
ranked_distances: Dict[RerankerID, Distances]
5757

5858

59-
RankedResults = Union[RerankedDocuments, List[RerankedQueryResult]]
59+
RankedResults = Union[RerankedDocuments, RerankedQueryResult]
6060

6161
D = TypeVar("D", bound=Rerankable, contravariant=True)
6262
T = TypeVar("T", bound=RankedResults, covariant=True)

chromadbx/reranking/cohere.py

Lines changed: 66 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, Optional, List
33

44
from chromadbx.reranking import (
55
Queries,
@@ -26,6 +26,19 @@ def __init__(
2626
max_retries: Optional[int] = 3,
2727
additional_headers: Optional[Dict[str, Any]] = None,
2828
):
29+
"""
30+
Initialize the CohereReranker.
31+
32+
Args:
33+
api_key: The Cohere API key.
34+
model_name: The Cohere model to use for reranking. Defaults to `rerank-v3.5`.
35+
raw_scores: Whether to return the raw scores from the Cohere API. Defaults to `False`.
36+
top_n: The number of results to return. Defaults to `None`.
37+
max_tokens_per_document: The maximum number of tokens per document. Defaults to `4096`.
38+
timeout: The timeout for the Cohere API request. Defaults to `60`.
39+
max_retries: The maximum number of retries for the Cohere API request. Defaults to `3`.
40+
additional_headers: Additional headers to include in the Cohere API request. Defaults to `None`.
41+
"""
2942
try:
3043
import cohere
3144
from cohere.core.request_options import RequestOptions
@@ -56,79 +69,70 @@ def id(self) -> RerankerID:
5669
return RerankerID("cohere")
5770

5871
def _combine_reranked_results(
59-
self, results: "cohere.v2.types.V2RerankResponse", rerankables: Rerankable # type: ignore # noqa: F821
72+
self, results_list: List["cohere.v2.types.V2RerankResponse"], rerankables: Rerankable # type: ignore # noqa: F821
6073
) -> RankedResults:
61-
"""
62-
{
63-
"results": [
64-
{
65-
"index": 3,
66-
"relevance_score": 0.999071
67-
},
68-
{
69-
"index": 4,
70-
"relevance_score": 0.7867867
71-
},
72-
{
73-
"index": 0,
74-
"relevance_score": 0.32713068
75-
}
76-
],
77-
"id": "07734bd2-2473-4f07-94e1-0d9f0e6843cf",
78-
"meta": {
79-
"api_version": {
80-
"version": "2",
81-
"is_experimental": false
82-
},
83-
"billed_units": {
84-
"search_units": 1
85-
}
86-
}
87-
}
88-
"""
74+
all_ordered_scores = []
75+
76+
for results in results_list:
77+
if self._raw_scores:
78+
ordered_scores = [
79+
r.relevance_score
80+
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
81+
]
82+
else: # by default we calculate the distance to make results comparable with Chroma distance
83+
ordered_scores = [
84+
1 - r.relevance_score
85+
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
86+
]
87+
all_ordered_scores.append(ordered_scores)
8988

90-
if self._raw_scores:
91-
ordered_scores = [
92-
r.relevance_score
93-
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
94-
]
95-
else: # by default we calculate the distance to make results comparable with Chroma distance
96-
ordered_scores = [
97-
1 - r.relevance_score
98-
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
99-
]
10089
if isinstance(rerankables, dict):
101-
if len(rerankables["ids"]) != len(ordered_scores):
102-
ordered_scores = ordered_scores + [None] * (
103-
len(rerankables["ids"]) - len(ordered_scores)
104-
)
105-
return [
106-
RerankedQueryResult(
107-
ids=rerankables["ids"],
108-
embeddings=rerankables["embeddings"],
109-
documents=rerankables["documents"],
110-
uris=rerankables["uris"],
111-
data=rerankables["data"],
112-
metadatas=rerankables["metadatas"],
113-
distances=rerankables["distances"],
114-
included=rerankables["included"],
115-
ranked_distances={self.id(): ordered_scores},
116-
)
90+
combined_ordered_scores = [
91+
score for sublist in all_ordered_scores for score in sublist
11792
]
93+
if len(rerankables["ids"]) != len(combined_ordered_scores):
94+
combined_ordered_scores = combined_ordered_scores + [None] * (
95+
len(rerankables["ids"]) - len(combined_ordered_scores)
96+
)
97+
return RerankedQueryResult(
98+
ids=rerankables["ids"],
99+
embeddings=rerankables["embeddings"]
100+
if "embeddings" in rerankables
101+
else None,
102+
documents=rerankables["documents"]
103+
if "documents" in rerankables
104+
else None,
105+
uris=rerankables["uris"] if "uris" in rerankables else None,
106+
data=rerankables["data"] if "data" in rerankables else None,
107+
metadatas=rerankables["metadatas"]
108+
if "metadatas" in rerankables
109+
else None,
110+
distances=rerankables["distances"]
111+
if "distances" in rerankables
112+
else None,
113+
included=rerankables["included"] if "included" in rerankables else None,
114+
ranked_distances={self.id(): combined_ordered_scores},
115+
)
118116
elif isinstance(rerankables, list):
119-
if len(rerankables) != len(ordered_scores):
120-
ordered_scores = ordered_scores + [None] * (
121-
len(rerankables) - len(ordered_scores)
117+
if len(results_list) > 1:
118+
raise ValueError("Cannot rerank documents with multiple results")
119+
combined_ordered_scores = [
120+
score for sublist in all_ordered_scores for score in sublist
121+
]
122+
if len(rerankables) != len(combined_ordered_scores):
123+
combined_ordered_scores = combined_ordered_scores + [None] * (
124+
len(rerankables) - len(combined_ordered_scores)
122125
)
123126
return RerankedDocuments(
124127
documents=rerankables,
125-
ranked_distances={self.id(): ordered_scores},
128+
ranked_distances={self.id(): combined_ordered_scores},
126129
)
127130
else:
128131
raise ValueError("Invalid rerankables type")
129132

130133
def __call__(self, queries: Queries, rerankables: Rerankable) -> RankedResults:
131134
query_documents_tuples = get_query_documents_tuples(queries, rerankables)
135+
results = []
132136
for query, documents in query_documents_tuples:
133137
response = self._client.rerank(
134138
model=self._model_name,
@@ -138,5 +142,5 @@ def __call__(self, queries: Queries, rerankables: Rerankable) -> RankedResults:
138142
max_tokens_per_doc=self._max_tokens_per_document,
139143
request_options=self._request_options,
140144
)
141-
print(response)
142-
return self._combine_reranked_results(response, rerankables)
145+
results.append(response)
146+
return self._combine_reranked_results(results, rerankables)

docs/reranking.md

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Reranking
2+
3+
Reranking is a process of reordering a list of items based on their relevance to a query. This project supports reranking of documents and query results.
4+
5+
```python
6+
from chromadbx.reranking.some_reranker import SomeReranker
7+
import chromadb
8+
some_reranker = SomeReranker()
9+
10+
client = chromadb.Client()
11+
12+
collection = client.get_collection("documents")
13+
14+
results = collection.query(
15+
query_texts=["What is the capital of the United States?"],
16+
n_results=10,
17+
)
18+
19+
reranked_results = some_reranker(results)
20+
21+
print("Documents:", reranked_results["documents"][0])
22+
print("Distances:", reranked_results["distances"][0])
23+
print("Reranked distances:", reranked_results["ranked_distances"][some_reranker.id()][0])
24+
```
25+
26+
> [!NOTE]
27+
> It is our intent that all officially supported reranking functions shall return distances instead of scores to be consistent with the core Chroma project. However, this is not a hard requirement and you should check the documentation for each reranking function you plan to use.
28+
29+
The following reranking functions are supported:
30+
31+
| Reranking Function | Official Docs |
32+
| ------------------ | ------------- |
33+
| [Cohere](#cohere) | [docs](https://docs.cohere.com/docs/rerank-2) |
34+
35+
## Cohere
36+
37+
Cohere reranking function offers a convinient wrapper around the Cohere API to rerank documents and query results. For more information on Cohere reranking, visit the official [docs](https://docs.cohere.com/docs/rerank-2) or [API docs](https://docs.cohere.com/reference/rerank).
38+
39+
You need to install the `cohere` package to use this reranking function.
40+
41+
42+
```bash
43+
pip install cohere # or poetry add cohere
44+
```
45+
46+
Before using the reranking function, you need to obtain [Cohere API](https://dashboard.cohere.com/api-keys) key and set the `COHERE_API_KEY` environment variable.
47+
48+
> [!TIP]
49+
> By default, the reranking function will return distances. If you need to get the raw scores, set the `raw_scores` parameter to `True`.
50+
51+
```python
52+
import os
53+
import chromadb
54+
from chromadbx.reranking import CohereReranker
55+
56+
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY"))
57+
58+
client = chromadb.Client()
59+
60+
collection = client.get_collection("documents")
61+
62+
results = collection.query(
63+
query_texts=["What is the capital of the United States?"],
64+
n_results=10,
65+
)
66+
67+
reranked_results = cohere(results)
68+
```
69+
70+
Available options:
71+
72+
- `api_key`: The Cohere API key.
73+
- `model_name`: The Cohere model to use for reranking. Defaults to `rerank-v3.5`.
74+
- `raw_scores`: Whether to return the raw scores from the Cohere API. Defaults to `False`.
75+
- `top_n`: The number of results to return. Defaults to `None`.
76+
- `max_tokens_per_document`: The maximum number of tokens per document. Defaults to `4096`.
77+
- `timeout`: The timeout for the Cohere API request. Defaults to `60`.
78+
- `max_retries`: The maximum number of retries for the Cohere API request. Defaults to `3`.
79+
- `additional_headers`: Additional headers to include in the Cohere API request. Defaults to `None`.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ llama-embedder = "^0.0.7"
3434
mistralai = "^1.1.0"
3535
spacy = "^3.8.4"
3636
together = "^1.3.11"
37+
cohere = "^5.13.8"
3738

3839
[tool.poetry.extras]
3940
ids = ["ulid-py", "nanoid"]
4041
embeddings = ["llama-embedder", "onnxruntime", "huggingface_hub", "mistralai", "spacy", "together", "vertexai"]
42+
reranking = ["cohere"]
4143
core = ["chromadb"]
4244

4345
[build-system]

test/reranking/test_cohere.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import os
2+
from typing import cast
23

4+
from chromadb import QueryResult
35
import pytest
6+
from chromadbx.reranking import RerankedDocuments, RerankedQueryResult
47
from chromadbx.reranking.cohere import CohereReranker
58

69

@@ -38,5 +41,30 @@ def test_cohere_rerank_documents() -> None:
3841
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY", ""))
3942
queries = "What is the capital of the United States?"
4043
rerankables = ["Washington, D.C.", "New York", "Los Angeles"]
41-
result = cohere(queries, rerankables)
42-
print(result)
44+
result = cast(RerankedDocuments, cohere(queries, rerankables))
45+
assert "ranked_distances" in result
46+
assert len(result["ranked_distances"][cohere.id()]) == len(rerankables)
47+
assert result["ranked_distances"][cohere.id()].index(
48+
min(result["ranked_distances"][cohere.id()])
49+
) == rerankables.index("Washington, D.C.")
50+
51+
52+
@pytest.mark.skipif(
53+
os.getenv("COHERE_API_KEY") is None,
54+
reason="COHERE_API_KEY environment variable is not set",
55+
)
56+
def test_cohere_rerank_documents_with_query_result() -> None:
57+
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY", ""))
58+
queries = ["What is the capital of the United States?"]
59+
rerankables = QueryResult(
60+
documents=[["Washington, D.C.", "New York", "Los Angeles"]],
61+
metadatas=[[{"source": "test"}, {"source": "test"}, {"source": "test"}]],
62+
embeddings=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
63+
ids=[["id1", "id2", "id3"]],
64+
)
65+
result = cast(RerankedQueryResult, cohere(queries, rerankables))
66+
assert "ranked_distances" in result
67+
assert len(result["ranked_distances"][cohere.id()]) == len(rerankables["ids"][0])
68+
assert result["ranked_distances"][cohere.id()].index(
69+
min(result["ranked_distances"][cohere.id()])
70+
) == rerankables["ids"][0].index("id1")

0 commit comments

Comments
 (0)