Skip to content

Commit b2875ab

Browse files
Gaudy BlancoGaudy Blanco
authored andcommitted
test fixes
1 parent dc00cb8 commit b2875ab

File tree

6 files changed

+362
-139
lines changed

6 files changed

+362
-139
lines changed

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
"module": "pytest",
4646
"args": [
4747
"./tests/integration/vector_stores",
48-
"-k", "test_lancedb"
48+
"-k", "test_azure_ai_search"
4949
],
5050
"console": "integratedTerminal",
5151
"justMyCode": false

graphrag/vector_stores/azure_ai_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
159159
# More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
160160
# search.in is faster that joined and/or conditions
161161
id_filter = ",".join([f"{id!s}" for id in include_ids])
162-
self.query_filter = f"search.in(id, '{id_filter}', ',')"
162+
self.query_filter = f"search.in({self.id_field}, '{id_filter}', ',')"
163163

164164
# Returning to keep consistency with other methods, but not needed
165165
# TODO: Refactor on a future PR

graphrag/vector_stores/lancedb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
9999
else:
100100
if isinstance(include_ids[0], str):
101101
id_filter = ", ".join([f"'{id}'" for id in include_ids])
102-
self.query_filter = f"id in ({id_filter})"
102+
self.query_filter = f"{self.id_field} in ({id_filter})"
103103
else:
104104
self.query_filter = (
105-
f"id in ({', '.join([str(id) for id in include_ids])})"
105+
f"{self.id_field} in ({', '.join([str(id) for id in include_ids])})"
106106
)
107107
return self.query_filter
108108

@@ -155,7 +155,7 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
155155
"""Search for a document by id."""
156156
doc = (
157157
self.document_collection.search()
158-
.where(f"id == '{id}'", prefilter=True)
158+
.where(f"{self.id_field} == '{id}'", prefilter=True)
159159
.to_list()
160160
)
161161
if doc:

tests/integration/vector_stores/test_azure_ai_search.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,33 @@ def mock_index_client(self):
4141
def vector_store(self, mock_search_client, mock_index_client):
4242
"""Create an Azure AI Search vector store instance."""
4343
vector_store = AzureAISearchVectorStore(
44-
collection_name="test_vectors",
45-
vector_store_schema_config=VectorStoreSchemaConfig(),
44+
vector_store_schema_config=VectorStoreSchemaConfig(
45+
index_name="test_vectors", vector_size=5
46+
),
47+
)
48+
49+
# Create the necessary mocks first
50+
vector_store.db_connection = mock_search_client
51+
vector_store.index_client = mock_index_client
52+
53+
vector_store.connect(
54+
url=TEST_AZURE_AI_SEARCH_URL,
55+
api_key=TEST_AZURE_AI_SEARCH_KEY,
56+
)
57+
return vector_store
58+
59+
@pytest.fixture
60+
def vector_store_custom(self, mock_search_client, mock_index_client):
61+
"""Create an Azure AI Search vector store instance."""
62+
vector_store = AzureAISearchVectorStore(
63+
vector_store_schema_config=VectorStoreSchemaConfig(
64+
index_name="test_vectors",
65+
id_field="id_custom",
66+
text_field="text_custom",
67+
attributes_field="attributes_custom",
68+
vector_field="vector_custom",
69+
vector_size=5,
70+
),
4671
)
4772

4873
# Create the necessary mocks first
@@ -52,7 +77,6 @@ def vector_store(self, mock_search_client, mock_index_client):
5277
vector_store.connect(
5378
url=TEST_AZURE_AI_SEARCH_URL,
5479
api_key=TEST_AZURE_AI_SEARCH_KEY,
55-
vector_size=5,
5680
)
5781
return vector_store
5882

@@ -148,3 +172,72 @@ def none_embedder(text: str) -> None:
148172
)
149173
assert not mock_search_client.search.called
150174
assert len(results) == 0
175+
176+
async def test_vector_store_customization(
177+
self,
178+
vector_store_custom,
179+
sample_documents,
180+
mock_search_client,
181+
mock_index_client,
182+
):
183+
"""Test vector store customization with Azure AI Search."""
184+
# Setup mock responses
185+
mock_index_client.list_index_names.return_value = []
186+
mock_index_client.create_or_update_index = MagicMock()
187+
mock_search_client.upload_documents = MagicMock()
188+
189+
search_results = [
190+
{
191+
vector_store_custom.id_field: "doc1",
192+
vector_store_custom.text_field: "This is document 1",
193+
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
194+
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
195+
"@search.score": 0.9,
196+
},
197+
{
198+
vector_store_custom.id_field: "doc2",
199+
vector_store_custom.text_field: "This is document 2",
200+
vector_store_custom.vector_field: [0.2, 0.3, 0.4, 0.5, 0.6],
201+
vector_store_custom.attributes_field: '{"title": "Doc 2", "category": "test"}',
202+
"@search.score": 0.8,
203+
},
204+
]
205+
mock_search_client.search.return_value = search_results
206+
207+
mock_search_client.get_document.return_value = {
208+
vector_store_custom.id_field: "doc1",
209+
vector_store_custom.text_field: "This is document 1",
210+
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
211+
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
212+
}
213+
214+
vector_store_custom.load_documents(sample_documents)
215+
assert mock_index_client.create_or_update_index.called
216+
assert mock_search_client.upload_documents.called
217+
218+
filter_query = vector_store_custom.filter_by_id(["doc1", "doc2"])
219+
assert (
220+
filter_query
221+
== f"search.in({vector_store_custom.id_field}, 'doc1,doc2', ',')"
222+
)
223+
224+
vector_results = vector_store_custom.similarity_search_by_vector(
225+
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
226+
)
227+
assert len(vector_results) == 2
228+
assert vector_results[0].document.id == "doc1"
229+
assert vector_results[0].score == 0.9
230+
231+
# Define a simple text embedder function for testing
232+
def mock_embedder(text: str) -> list[float]:
233+
return [0.1, 0.2, 0.3, 0.4, 0.5]
234+
235+
text_results = vector_store_custom.similarity_search_by_text(
236+
"test query", mock_embedder, k=2
237+
)
238+
assert len(text_results) == 2
239+
240+
doc = vector_store_custom.search_by_id("doc1")
241+
assert doc.id == "doc1"
242+
assert doc.text == "This is document 1"
243+
assert doc.attributes["title"] == "Doc 1"

tests/integration/vector_stores/test_cosmosdb.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,64 @@ def test_clear():
103103
assert vector_store._database_exists() is False # noqa: SLF001
104104
finally:
105105
pass
106+
107+
108+
def test_vector_store_customization():
109+
"""Test vector store customization with CosmosDB."""
110+
vector_store = CosmosDBVectorStore(
111+
vector_store_schema_config=VectorStoreSchemaConfig(
112+
index_name="text-embeddings",
113+
id_field="id_custom",
114+
text_field="text_custom",
115+
vector_field="vector_custom",
116+
attributes_field="attributes_custom",
117+
vector_size=5,
118+
),
119+
)
120+
121+
try:
122+
vector_store.connect(
123+
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
124+
database_name="test_db",
125+
)
126+
127+
docs = [
128+
VectorStoreDocument(
129+
id="doc1",
130+
text="This is document 1",
131+
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
132+
attributes={"title": "Doc 1", "category": "test"},
133+
),
134+
VectorStoreDocument(
135+
id="doc2",
136+
text="This is document 2",
137+
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
138+
attributes={"title": "Doc 2", "category": "test"},
139+
),
140+
]
141+
vector_store.load_documents(docs)
142+
143+
vector_store.filter_by_id(["doc1"])
144+
145+
doc = vector_store.search_by_id("doc1")
146+
assert doc.id == "doc1"
147+
assert doc.text == "This is document 1"
148+
assert doc.vector is not None
149+
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
150+
assert doc.attributes["title"] == "Doc 1"
151+
152+
# Define a simple text embedder function for testing
153+
def mock_embedder(text: str) -> list[float]:
154+
return [0.1, 0.2, 0.3, 0.4, 0.5] # Return fixed embedding
155+
156+
vector_results = vector_store.similarity_search_by_vector(
157+
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
158+
)
159+
assert len(vector_results) > 0
160+
161+
text_results = vector_store.similarity_search_by_text(
162+
"test query", mock_embedder, k=2
163+
)
164+
assert len(text_results) > 0
165+
finally:
166+
vector_store.clear()

0 commit comments

Comments
 (0)