Skip to content

Commit 2cb6dc3

Browse files
committed
Fix custom vector store notebook
1 parent bad577c commit 2cb6dc3

File tree

5 files changed

+22
-21
lines changed

5 files changed

+22
-21
lines changed

docs/examples_notebooks/custom_vector_store.ipynb

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"import numpy as np\n",
6262
"import yaml\n",
6363
"\n",
64+
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
6465
"from graphrag.data_model.types import TextEmbedder\n",
6566
"\n",
6667
"# GraphRAG vector store components\n",
@@ -147,14 +148,12 @@
147148
" self.vectors: dict[str, np.ndarray] = {}\n",
148149
" self.connected = False\n",
149150
"\n",
150-
" print(\n",
151-
" f\"🚀 SimpleInMemoryVectorStore initialized for collection: {self.collection_name}\"\n",
152-
" )\n",
151+
" print(f\"🚀 SimpleInMemoryVectorStore initialized for index: {self.index_name}\")\n",
153152
"\n",
154153
" def connect(self, **kwargs: Any) -> None:\n",
155154
" \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
156155
" self.connected = True\n",
157-
" print(f\"✅ Connected to in-memory vector store: {self.collection_name}\")\n",
156+
" print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
158157
"\n",
159158
" def load_documents(\n",
160159
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
@@ -250,7 +249,7 @@
250249
" def get_stats(self) -> dict[str, Any]:\n",
251250
" \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
252251
" return {\n",
253-
" \"collection_name\": self.collection_name,\n",
252+
" \"index_name\": self.index_name,\n",
254253
" \"document_count\": len(self.documents),\n",
255254
" \"vector_count\": len(self.vectors),\n",
256255
" \"connected\": self.connected,\n",
@@ -353,11 +352,11 @@
353352
"outputs": [],
354353
"source": [
355354
"# Test creating vector store using the factory\n",
356-
"vector_store_config = {\"collection_name\": \"test_collection\"}\n",
355+
"schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
357356
"\n",
358357
"# Create vector store instance using factory\n",
359358
"vector_store = VectorStoreFactory.create_vector_store(\n",
360-
" CUSTOM_VECTOR_STORE_TYPE, vector_store_config\n",
359+
" CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
361360
")\n",
362361
"\n",
363362
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
@@ -486,9 +485,13 @@
486485
" print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
487486
"\n",
488487
" # 1. GraphRAG creates vector store using factory\n",
489-
" config = {\"collection_name\": \"graphrag_entities\", \"similarity_threshold\": 0.3}\n",
488+
" schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
490489
"\n",
491-
" store = VectorStoreFactory.create_vector_store(CUSTOM_VECTOR_STORE_TYPE, config)\n",
490+
" store = VectorStoreFactory.create_vector_store(\n",
491+
" CUSTOM_VECTOR_STORE_TYPE,\n",
492+
" vector_store_schema_config=schema,\n",
493+
" similarity_threshold=0.3,\n",
494+
" )\n",
492495
" store.connect()\n",
493496
"\n",
494497
" print(\"✅ Step 1: Vector store created and connected\")\n",
@@ -549,7 +552,8 @@
549552
" # Test 1: Basic functionality\n",
550553
" print(\"Test 1: Basic functionality\")\n",
551554
" store = VectorStoreFactory.create_vector_store(\n",
552-
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test\"}\n",
555+
" CUSTOM_VECTOR_STORE_TYPE,\n",
556+
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
553557
" )\n",
554558
" store.connect()\n",
555559
"\n",
@@ -597,7 +601,8 @@
597601
" # Test 5: Error handling\n",
598602
" print(\"\\nTest 5: Error handling\")\n",
599603
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
600-
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test2\"}\n",
604+
" CUSTOM_VECTOR_STORE_TYPE,\n",
605+
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
601606
" )\n",
602607
"\n",
603608
" try:\n",
@@ -653,7 +658,7 @@
653658
],
654659
"metadata": {
655660
"kernelspec": {
656-
"display_name": "graphrag-venv (3.10.18)",
661+
"display_name": "graphrag",
657662
"language": "python",
658663
"name": "python3"
659664
},
@@ -667,7 +672,7 @@
667672
"name": "python",
668673
"nbconvert_exporter": "python",
669674
"pygments_lexer": "ipython3",
670-
"version": "3.10.18"
675+
"version": "3.12.10"
671676
}
672677
},
673678
"nbformat": 4,

graphrag/index/operations/embed_text/embed_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def _create_vector_store(
210210
vector_store = VectorStoreFactory().create_vector_store(
211211
vector_store_schema_config=single_embedding_config,
212212
vector_store_type=vector_store_type,
213-
kwargs=vector_store_config,
213+
**vector_store_config,
214214
)
215215

216216
vector_store.connect(**vector_store_config)

graphrag/utils/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def get_embedding_store(
130130
embedding_store = VectorStoreFactory().create_vector_store(
131131
vector_store_type=vector_store_type,
132132
vector_store_schema_config=single_embedding_config,
133-
kwargs={**store},
133+
**store,
134134
)
135135
embedding_store.connect(**store)
136136
# If there is only a single index, return the embedding store directly

graphrag/vector_stores/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def create_vector_store(
5353
cls,
5454
vector_store_type: str,
5555
vector_store_schema_config: VectorStoreSchemaConfig,
56-
kwargs: dict,
56+
**kwargs: dict,
5757
) -> BaseVectorStore:
5858
"""Create a vector store object from the provided type.
5959

tests/integration/vector_stores/test_factory.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ def test_register_and_create_custom_vector_store():
8181
)
8282

8383
vector_store = VectorStoreFactory.create_vector_store(
84-
vector_store_type="custom",
85-
vector_store_schema_config=VectorStoreSchemaConfig(),
86-
kwargs={},
84+
vector_store_type="custom", vector_store_schema_config=VectorStoreSchemaConfig()
8785
)
8886

8987
assert custom_vector_store_class.called
@@ -109,7 +107,6 @@ def test_create_unknown_vector_store():
109107
VectorStoreFactory.create_vector_store(
110108
vector_store_type="unknown",
111109
vector_store_schema_config=VectorStoreSchemaConfig(),
112-
kwargs={},
113110
)
114111

115112

@@ -162,7 +159,6 @@ def search_by_id(self, id):
162159
vector_store = VectorStoreFactory.create_vector_store(
163160
vector_store_type="custom_class",
164161
vector_store_schema_config=VectorStoreSchemaConfig(),
165-
kwargs={"collection_name": "test"},
166162
)
167163

168164
assert isinstance(vector_store, CustomVectorStore)

0 commit comments

Comments
 (0)