Skip to content

Commit 093d24a

Browse files
authored
colbert integration test on cassandra (#396)
* colbert integration test on cassandra * use nested event_loop * fix file name typo * add comments
1 parent 2a02d34 commit 093d24a

File tree

5 files changed

+8
-3
lines changed

5 files changed

+8
-3
lines changed

libs/colbert/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ colbert-ai = "0.2.19"
1515
pyarrow = "14.0.1"
1616
torch = "2.2.1"
1717
cassio = "~0.1.7"
18+
nest-asyncio = "^1.6.0"
1819

1920
[tool.poetry.group.test.dependencies]
2021
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }

libs/colbert/ragstack_colbert/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from .cassandra_vector_store import CassandraVectorStore
18-
from .cobert_retriever import ColbertRetriever
18+
from .colbert_retriever import ColbertRetriever
1919
from .colbert_embedding_model import ColbertEmbeddingModel
2020
from .constant import DEFAULT_COLBERT_DIM, DEFAULT_COLBERT_MODEL
2121
from .objects import ChunkData, EmbeddedChunk, RetrievedChunk

libs/colbert/ragstack_colbert/cobert_retriever.py renamed to libs/colbert/ragstack_colbert/colbert_retriever.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import asyncio
1212
import logging
1313
import math
14+
import nest_asyncio
1415
from typing import Any, Dict, List, Optional, Set
1516

1617
import torch
@@ -336,5 +337,8 @@ def retrieve(
336337
The actual retrieval process involves encoding the query, performing an ANN search to find relevant
337338
embeddings, scoring these embeddings for similarity, and retrieving the corresponding text chunks.
338339
"""
340+
# nest_asyncio does not require a new event loop to be created
341+
# it is required when an event loop is already running, such as in colab
342+
nest_asyncio.apply()
339343
loop = asyncio.get_event_loop()
340344
return loop.run_until_complete(self.aretrieve(query=query, k=k, query_maxlen=query_maxlen))

libs/colbert/tests/integration_tests/test_colbert_embedding_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def astra_db():
2525

2626

2727
#@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
28-
@pytest.mark.parametrize("vector_store", ["astra_db"])
28+
@pytest.mark.parametrize("vector_store", ["cassandra"])
2929
def test_embedding_cassandra_retriever(request, vector_store: str):
3030
vector_store = request.getfixturevalue(vector_store)
3131
narrative = """

libs/colbert/tests/unit_tests/test_colbert_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
from ragstack_colbert.cobert_retriever import max_similarity_torch
3+
from ragstack_colbert.colbert_retriever import max_similarity_torch
44
from ragstack_colbert.colbert_embedding_model import calculate_query_maxlen
55

66

0 commit comments

Comments
 (0)