Skip to content

Commit b9b8995

Browse files
authored
fix colbert rank calculation (#399)
1 parent 093d24a commit b9b8995

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

libs/colbert/ragstack_colbert/colbert_retriever.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,16 @@ async def _fetch_chunk_data(
246246
chunk_data_map[chunk] = chunk_data
247247

248248
answers: List[RetrievedChunk] = []
249-
rank = 1
250-
for chunk in chunks_by_score:
249+
250+
for idx, chunk in enumerate(chunks_by_score):
251251
score = chunk_scores[chunk]
252252
chunk_data = chunk_data_map[chunk]
253253
answers.append(
254254
RetrievedChunk(
255255
doc_id=chunk.doc_id,
256256
chunk_id=chunk.chunk_id,
257257
score=score.item(), # Ensure score is a scalar if it's a tensor
258-
rank=rank,
258+
rank=idx + 1,
259259
data=chunk_data,
260260
)
261261
)

libs/colbert/tests/integration_tests/test_colbert_embedding_retrieval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,6 @@ def generate_texts(text, chunk_size, overlap_size):
9696
for chunk in chunks:
9797
logging.info(f"got {chunk}")
9898
assert len(chunks) == 5
99-
assert len(chunks[0].data.text) > 0
99+
assert len(chunks[0].data.text) > 0
100+
assert chunks[0].rank == 1
101+
assert chunks[1].rank == 2

0 commit comments

Comments
 (0)