Skip to content

Commit e717795

Browse files
authored
[RAG-103] Init AstraDB LlamaIndex tests (#149)
1 parent 9b7b348 commit e717795

File tree

2 files changed

+217
-15
lines changed

2 files changed

+217
-15
lines changed

ragstack-e2e-tests/e2e_tests/langchain/test_astra.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_ingest_errors(environment):
3131
empty_text = ""
3232

3333
try:
34-
# empty test is not allowed
34+
# empty text computes embeddings vector as all zeroes and this is not allowed
3535
vectorstore.add_texts([empty_text])
3636
except ValueError as e:
3737
print("Error:", e)
@@ -56,20 +56,6 @@ def test_ingest_errors(environment):
5656
f"Should have thrown ValueError with SHRED_DOC_LIMIT_VIOLATION but it was {e}"
5757
)
5858

59-
very_very_long_text = (
60-
"RAGStack is a framework to run LangChain in production. " * 10000
61-
)
62-
try:
63-
vectorstore.add_texts([very_very_long_text])
64-
pytest.fail("Should have thrown ValueError")
65-
except ValueError as e:
66-
print("Error:", e)
67-
# API Exception while running bulk insertion: {'errors': [{'message': 'Document size limitation violated: String value length (560000) exceeds maximum allowed (16000)', 'errorCode': 'SHRED_DOC_LIMIT_VIOLATION'}]}
68-
if "SHRED_DOC_LIMIT_VIOLATION" not in e.args[0]:
69-
pytest.fail(
70-
f"Should have thrown ValueError with SHRED_DOC_LIMIT_VIOLATION but it was {e}"
71-
)
72-
7359

7460
def test_wrong_connection_parameters():
7561
# This is expected to be a valid endpoint, because we want to test an AUTHENTICATION error
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import logging
2+
from typing import List
3+
4+
from astrapy.db import AstraDB as LibAstraDB
5+
import pytest
6+
from httpx import ConnectError
7+
from e2e_tests.conftest import get_required_env, get_astra_ref
8+
from llama_index import (
9+
ServiceContext,
10+
StorageContext,
11+
VectorStoreIndex,
12+
Document,
13+
)
14+
from llama_index.embeddings import BaseEmbedding
15+
from llama_index.llms import OpenAI, LLM
16+
from llama_index.node_parser import SimpleNodeParser
17+
from llama_index.vector_stores import AstraDBVectorStore
18+
19+
20+
def test_basic_vector_search(environment):
    """Ingest one document and verify it is retrievable from the vector store."""
    print("Running test_basic_vector_search")
    docs = [Document(text="RAGStack is a framework to run LangChain in production")]

    index = VectorStoreIndex.from_documents(
        docs,
        storage_context=environment.storage_context,
        service_context=environment.service_context,
    )

    # The retriever must return at least one hit for a term from the document.
    hits = index.as_retriever().retrieve("RAGStack")
    assert len(hits) > 0
35+
36+
37+
def test_ingest_errors(environment):
    """Verify that invalid documents are rejected by the Astra vector store.

    Two failure modes are covered:
    * empty text -> an all-zero embedding vector, which the cosine index rejects
    * oversized text with splitting disabled -> SHRED_DOC_LIMIT_VIOLATION
    """
    print("Running test_ingest_errors")

    empty_text = ""

    try:
        # empty text computes embeddings vector as all zeroes and this is not allowed
        documents = [Document(text=empty_text)]
        VectorStoreIndex.from_documents(
            documents,
            storage_context=environment.storage_context,
            service_context=environment.service_context,
        )
        # The insert must not succeed (mirrors the oversized-text case below).
        pytest.fail("Should have thrown ValueError")
    except ValueError as e:
        print("Error:", e)
        # API Exception while running bulk insertion: [{'message': "Failed to insert document with _id 'b388435404254c17b720816ee9e0ddc4': Zero vectors cannot be indexed or queried with cosine similarity"}]
        if (
            "Zero vectors cannot be indexed or queried with cosine similarity"
            not in e.args[0]
        ):
            pytest.fail(
                f"Should have thrown ValueError with Zero vectors cannot be indexed or queried with cosine similarity but it was {e}"
            )

    very_long_text = "RAGStack is a framework to run LangChain in production. " * 1000

    # with the default set of transformations this write succeeds because LI
    # automatically does text splitting
    documents = [Document(text=very_long_text)]
    VectorStoreIndex.from_documents(
        documents,
        storage_context=environment.storage_context,
        service_context=environment.service_context,
    )

    # if we disable text splitting, the same text is too long for a single
    # document and the write fails
    try:
        documents = [Document(text=very_long_text)]
        VectorStoreIndex.from_documents(
            documents,
            storage_context=environment.storage_context,
            service_context=environment.service_context_no_splitting,
        )
        pytest.fail("Should have thrown ValueError")
    except ValueError as e:
        print("Error:", e)
        # API Exception while running bulk insertion: {'errors': [{'message': 'Document size limitation violated: String value length (56000) exceeds maximum allowed (16000)', 'errorCode': 'SHRED_DOC_LIMIT_VIOLATION'}]}
        if "SHRED_DOC_LIMIT_VIOLATION" not in e.args[0]:
            pytest.fail(
                f"Should have thrown ValueError with SHRED_DOC_LIMIT_VIOLATION but it was {e}"
            )
90+
91+
92+
def test_wrong_connection_parameters():
    """An unreachable endpoint must raise a connection error; a bad token an
    authentication error."""
    # This is expected to be a valid endpoint, because we want to test an
    # AUTHENTICATION error further down.
    astra_ref = get_astra_ref()
    api_endpoint = astra_ref.api_endpoint

    try:
        AstraDBVectorStore(
            token="xxxxx",
            # we assume that port 1234 is not open locally
            api_endpoint="https://locahost:1234",
            collection_name="something",
            embedding_dimension=1536,
        )
    except ConnectError as e:
        print("Error:", e)
    else:
        pytest.fail("Should have thrown exception")

    try:
        print("api_endpoint:", api_endpoint)
        AstraDBVectorStore(
            token="this-is-a-wrong-token",
            api_endpoint=api_endpoint,
            collection_name="something",
            embedding_dimension=1536,
        )
    except ValueError as e:
        print("Error:", e)
        if "AUTHENTICATION ERROR" not in e.args[0]:
            pytest.fail(
                f"Should have thrown ValueError with AUTHENTICATION ERROR but it was {e}"
            )
    else:
        pytest.fail("Should have thrown exception")
125+
126+
127+
def init_vector_db() -> AstraDBVectorStore:
    """Delete every existing collection in the Astra database, then build a
    fresh AstraDBVectorStore for the configured test collection.

    NOTE(review): this drops ALL collections in the target database — only
    point it at a dedicated test database.
    """
    astra_ref = get_astra_ref()

    raw_client = LibAstraDB(api_endpoint=astra_ref.api_endpoint, token=astra_ref.token)
    existing = raw_client.get_collections().get("status").get("collections")
    logging.info(f"Existing collections: {existing}")
    for name in existing:
        try:
            logging.info(f"Deleting collection: {name}")
            raw_client.delete_collection(name)
        except Exception as e:
            # best-effort cleanup: log and keep deleting the rest
            logging.error(f"Error while deleting collection {name}: {e}")

    # dimension 3 matches the 3-element vectors produced by MockEmbeddings
    return AstraDBVectorStore(
        token=astra_ref.token,
        api_endpoint=astra_ref.api_endpoint,
        collection_name=astra_ref.collection,
        embedding_dimension=3,
    )
151+
152+
153+
class Environment:
    """Bundles the vector store, LLM and embedding model together with the
    llama_index service/storage contexts derived from them."""

    def __init__(
        self, vectorstore: AstraDBVectorStore, llm: LLM, embedding: BaseEmbedding
    ):
        self.vectorstore = vectorstore
        self.llm = llm
        self.embedding = embedding
        # Default context: standard transformations (includes text splitting).
        self.service_context = ServiceContext.from_defaults(
            embed_model=self.embedding, llm=self.llm
        )
        # A parser with a huge chunk size effectively disables text splitting.
        no_split_parser = SimpleNodeParser.from_defaults(
            chunk_size=100000000, include_prev_next_rel=False, include_metadata=True
        )
        self.service_context_no_splitting = ServiceContext.from_defaults(
            embed_model=self.embedding,
            llm=self.llm,
            transformations=[no_split_parser],
        )
        self.storage_context = StorageContext.from_defaults(vector_store=vectorstore)
172+
173+
174+
@pytest.fixture
def environment():
    """Yield a fully wired Environment; the collection is dropped on teardown."""
    embedding = init_embeddings()
    vector_db = init_vector_db()
    llm = init_llm()
    yield Environment(vectorstore=vector_db, llm=llm, embedding=embedding)
    close_vector_db(vector_db)
183+
184+
185+
def close_vector_db(vector_store: AstraDBVectorStore):
    """Drop the collection backing *vector_store*.

    NOTE(review): reaches into the private ``_astra_db`` /
    ``_astra_db_collection`` attributes of AstraDBVectorStore — may break
    across llama_index versions; confirm when upgrading.
    """
    collection_name = vector_store._astra_db_collection.collection_name
    vector_store._astra_db.delete_collection(collection_name)
189+
190+
191+
class MockEmbeddings(BaseEmbedding):
    """Deterministic, cheap embeddings derived from the input text length.

    Produces 3-dimensional vectors, matching the embedding_dimension=3 used
    when the test collection is created.
    """

    @staticmethod
    def mock_embedding(text: str):
        # Three simple length-based components — deterministic and non-zero
        # for any non-empty text.
        vector = [len(text) / 2, len(text) / 5, len(text) / 10]
        logging.info("mock_embedding for " + text + " : " + str(vector))
        return vector

    def _get_query_embedding(self, query: str) -> List[float]:
        return self.mock_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self.mock_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        return self.mock_embedding(text)
206+
207+
208+
def init_embeddings() -> BaseEmbedding:
    """Return the deterministic mock embedding model used by these tests."""
    return MockEmbeddings()
210+
211+
212+
def init_llm() -> LLM:
    """Build a deterministic (temperature=0) OpenAI LLM from OPEN_AI_KEY."""
    key = get_required_env("OPEN_AI_KEY")
    return OpenAI(
        api_key=key,
        model="gpt-3.5-turbo-16k",
        streaming=False,
        temperature=0,
    )

0 commit comments

Comments
 (0)