Skip to content

Commit 61769dd

Browse files
Vector Store Integration Tests (#1856)
* Add vector store id reference to embeddings config. * generated initial vector store pytests * cleaned up cosmosdb vector store test * fixed class name typo and debugged cosmosdb vector store test * reset emulator connection string * remove unneccessary comments * removed extra comments from azure ai search test * ruff * semversioner * fix cicd issues * bypass diskANN policy for test env * handle floating point inprecisions --------- Co-authored-by: Derek Worthen <[email protected]>
1 parent ffd8db7 commit 61769dd

File tree

7 files changed

+503
-18
lines changed

7 files changed

+503
-18
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "add vector store integration tests"
4+
}

graphrag/vector_stores/cosmosdb.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy
10+
from azure.cosmos.exceptions import CosmosHttpResponseError
1011
from azure.cosmos.partition_key import PartitionKey
1112
from azure.identity import DefaultAzureCredential
1213

@@ -19,7 +20,7 @@
1920
)
2021

2122

22-
class CosmosDBVectoreStore(BaseVectorStore):
23+
class CosmosDBVectorStore(BaseVectorStore):
2324
"""Azure CosmosDB vector storage implementation."""
2425

2526
_cosmos_client: CosmosClient
@@ -99,16 +100,32 @@ def _create_container(self) -> None:
99100
"automatic": True,
100101
"includedPaths": [{"path": "/*"}],
101102
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
102-
"vectorIndexes": [{"path": "/vector", "type": "diskANN"}],
103103
}
104104

105-
# Create the container and container client
106-
self._database_client.create_container_if_not_exists(
107-
id=self._container_name,
108-
partition_key=partition_key,
109-
indexing_policy=indexing_policy,
110-
vector_embedding_policy=vector_embedding_policy,
111-
)
105+
# Currently, the CosmosDB emulator does not support the diskANN policy.
106+
try:
107+
# First try with the standard diskANN policy
108+
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
109+
110+
# Create the container and container client
111+
self._database_client.create_container_if_not_exists(
112+
id=self._container_name,
113+
partition_key=partition_key,
114+
indexing_policy=indexing_policy,
115+
vector_embedding_policy=vector_embedding_policy,
116+
)
117+
except CosmosHttpResponseError:
118+
# If diskANN fails (likely in emulator), retry without vector indexes
119+
indexing_policy.pop("vectorIndexes", None)
120+
121+
# Create the container with compatible indexing policy
122+
self._database_client.create_container_if_not_exists(
123+
id=self._container_name,
124+
partition_key=partition_key,
125+
indexing_policy=indexing_policy,
126+
vector_embedding_policy=vector_embedding_policy,
127+
)
128+
112129
self._container_client = self._database_client.get_container_client(
113130
self._container_name
114131
)
@@ -157,13 +174,46 @@ def similarity_search_by_vector(
157174
msg = "Container client is not initialized."
158175
raise ValueError(msg)
159176

160-
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
161-
query_params = [{"name": "@embedding", "value": query_embedding}]
162-
items = self._container_client.query_items(
163-
query=query,
164-
parameters=query_params,
165-
enable_cross_partition_query=True,
166-
)
177+
try:
178+
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
179+
query_params = [{"name": "@embedding", "value": query_embedding}]
180+
items = list(
181+
self._container_client.query_items(
182+
query=query,
183+
parameters=query_params,
184+
enable_cross_partition_query=True,
185+
)
186+
)
187+
except (CosmosHttpResponseError, ValueError):
188+
# Currently, the CosmosDB emulator does not support the VectorDistance function.
189+
# For emulator or test environments - fetch all items and calculate distance locally
190+
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
191+
items = list(
192+
self._container_client.query_items(
193+
query=query,
194+
enable_cross_partition_query=True,
195+
)
196+
)
197+
198+
# Calculate cosine similarity locally (1 - cosine distance)
199+
from numpy import dot
200+
from numpy.linalg import norm
201+
202+
def cosine_similarity(a, b):
203+
if norm(a) * norm(b) == 0:
204+
return 0.0
205+
return dot(a, b) / (norm(a) * norm(b))
206+
207+
# Calculate scores for all items
208+
for item in items:
209+
item_vector = item.get("vector", [])
210+
similarity = cosine_similarity(query_embedding, item_vector)
211+
item["SimilarityScore"] = similarity
212+
213+
# Sort by similarity score (higher is better) and take top k
214+
items = sorted(
215+
items, key=lambda x: x.get("SimilarityScore", 0.0), reverse=True
216+
)[:k]
167217

168218
return [
169219
VectorStoreSearchResult(
@@ -214,3 +264,8 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
214264
text=item.get("text", ""),
215265
attributes=(json.loads(item.get("attributes", "{}"))),
216266
)
267+
268+
def clear(self) -> None:
269+
"""Clear the vector store."""
270+
self._delete_container()
271+
self._delete_database()

graphrag/vector_stores/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
1010
from graphrag.vector_stores.base import BaseVectorStore
11-
from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore
11+
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
1212
from graphrag.vector_stores.lancedb import LanceDBVectorStore
1313

1414

@@ -44,7 +44,7 @@ def create_vector_store(
4444
case VectorStoreType.AzureAISearch:
4545
return AzureAISearchVectorStore(**kwargs)
4646
case VectorStoreType.CosmosDB:
47-
return CosmosDBVectoreStore(**kwargs)
47+
return CosmosDBVectorStore(**kwargs)
4848
case _:
4949
if vector_store_type in cls.vector_store_types:
5050
return cls.vector_store_types[vector_store_type](**kwargs)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""Integration tests for vector store implementations."""
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""Integration tests for Azure AI Search vector store implementation."""
5+
6+
import os
7+
from unittest.mock import MagicMock, patch
8+
9+
import pytest
10+
11+
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
12+
from graphrag.vector_stores.base import VectorStoreDocument
13+
14+
TEST_AZURE_AI_SEARCH_URL = os.environ.get(
15+
"TEST_AZURE_AI_SEARCH_URL", "https://test-url.search.windows.net"
16+
)
17+
TEST_AZURE_AI_SEARCH_KEY = os.environ.get("TEST_AZURE_AI_SEARCH_KEY", "test_api_key")
18+
19+
20+
class TestAzureAISearchVectorStore:
21+
"""Test class for AzureAISearchVectorStore."""
22+
23+
@pytest.fixture
24+
def mock_search_client(self):
25+
"""Create a mock Azure AI Search client."""
26+
with patch(
27+
"graphrag.vector_stores.azure_ai_search.SearchClient"
28+
) as mock_client:
29+
yield mock_client.return_value
30+
31+
@pytest.fixture
32+
def mock_index_client(self):
33+
"""Create a mock Azure AI Search index client."""
34+
with patch(
35+
"graphrag.vector_stores.azure_ai_search.SearchIndexClient"
36+
) as mock_client:
37+
yield mock_client.return_value
38+
39+
@pytest.fixture
40+
def vector_store(self, mock_search_client, mock_index_client):
41+
"""Create an Azure AI Search vector store instance."""
42+
vector_store = AzureAISearchVectorStore(collection_name="test_vectors")
43+
44+
# Create the necessary mocks first
45+
vector_store.db_connection = mock_search_client
46+
vector_store.index_client = mock_index_client
47+
48+
vector_store.connect(
49+
url=TEST_AZURE_AI_SEARCH_URL,
50+
api_key=TEST_AZURE_AI_SEARCH_KEY,
51+
vector_size=5,
52+
)
53+
return vector_store
54+
55+
@pytest.fixture
56+
def sample_documents(self):
57+
"""Create sample documents for testing."""
58+
return [
59+
VectorStoreDocument(
60+
id="doc1",
61+
text="This is document 1",
62+
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
63+
attributes={"title": "Doc 1", "category": "test"},
64+
),
65+
VectorStoreDocument(
66+
id="doc2",
67+
text="This is document 2",
68+
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
69+
attributes={"title": "Doc 2", "category": "test"},
70+
),
71+
]
72+
73+
async def test_vector_store_operations(
74+
self, vector_store, sample_documents, mock_search_client, mock_index_client
75+
):
76+
"""Test basic vector store operations with Azure AI Search."""
77+
# Setup mock responses
78+
mock_index_client.list_index_names.return_value = []
79+
mock_index_client.create_or_update_index = MagicMock()
80+
mock_search_client.upload_documents = MagicMock()
81+
82+
search_results = [
83+
{
84+
"id": "doc1",
85+
"text": "This is document 1",
86+
"vector": [0.1, 0.2, 0.3, 0.4, 0.5],
87+
"attributes": '{"title": "Doc 1", "category": "test"}',
88+
"@search.score": 0.9,
89+
},
90+
{
91+
"id": "doc2",
92+
"text": "This is document 2",
93+
"vector": [0.2, 0.3, 0.4, 0.5, 0.6],
94+
"attributes": '{"title": "Doc 2", "category": "test"}',
95+
"@search.score": 0.8,
96+
},
97+
]
98+
mock_search_client.search.return_value = search_results
99+
100+
mock_search_client.get_document.return_value = {
101+
"id": "doc1",
102+
"text": "This is document 1",
103+
"vector": [0.1, 0.2, 0.3, 0.4, 0.5],
104+
"attributes": '{"title": "Doc 1", "category": "test"}',
105+
}
106+
107+
vector_store.load_documents(sample_documents)
108+
assert mock_index_client.create_or_update_index.called
109+
assert mock_search_client.upload_documents.called
110+
111+
filter_query = vector_store.filter_by_id(["doc1", "doc2"])
112+
assert filter_query == "search.in(id, 'doc1,doc2', ',')"
113+
114+
vector_results = vector_store.similarity_search_by_vector(
115+
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
116+
)
117+
assert len(vector_results) == 2
118+
assert vector_results[0].document.id == "doc1"
119+
assert vector_results[0].score == 0.9
120+
121+
# Define a simple text embedder function for testing
122+
def mock_embedder(text: str) -> list[float]:
123+
return [0.1, 0.2, 0.3, 0.4, 0.5]
124+
125+
text_results = vector_store.similarity_search_by_text(
126+
"test query", mock_embedder, k=2
127+
)
128+
assert len(text_results) == 2
129+
130+
doc = vector_store.search_by_id("doc1")
131+
assert doc.id == "doc1"
132+
assert doc.text == "This is document 1"
133+
assert doc.attributes["title"] == "Doc 1"
134+
135+
async def test_empty_embedding(self, vector_store, mock_search_client):
136+
"""Test similarity search by text with empty embedding."""
137+
138+
# Create a mock embedder that returns None and verify that no results are produced
139+
def none_embedder(text: str) -> None:
140+
return None
141+
142+
results = vector_store.similarity_search_by_text(
143+
"test query", none_embedder, k=1
144+
)
145+
assert not mock_search_client.search.called
146+
assert len(results) == 0

0 commit comments

Comments
 (0)