Skip to content

Commit 0436405

Browse files
gaudyb and Gaudy Blanco authored
Remove document overwrite (#2101)
* remove document overwrite from vector store configuration
* remove document overwrite and refactor load documents method
* fix test
* fix test
* fix test

---------

Co-authored-by: Gaudy Blanco <[email protected]>
1 parent 5ec49fd commit 0436405

File tree

13 files changed

+1878
-1873
lines changed

13 files changed

+1878
-1873
lines changed

graphrag/config/models/vector_store_config.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,6 @@ def _validate_url(self) -> None:
8282
default=vector_store_defaults.database_name,
8383
)
8484

85-
overwrite: bool = Field(
86-
description="Overwrite the existing data.",
87-
default=vector_store_defaults.overwrite,
88-
)
89-
9085
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
9186

9287
def _validate_embeddings_schema(self) -> None:

graphrag/index/operations/embed_text/embed_text.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,21 @@ async def embed_text(
4444
msg = f"Column {id_column} not found in input dataframe with columns {input.columns}"
4545
raise ValueError(msg)
4646

47-
total_rows = 0
48-
for row in input[embed_column]:
49-
if isinstance(row, list):
50-
total_rows += len(row)
51-
else:
52-
total_rows += 1
47+
vector_store.create_index()
5348

54-
i = 0
55-
starting_index = 0
49+
index = 0
5650

5751
all_results = []
5852

5953
num_total_batches = (input.shape[0] + batch_size - 1) // batch_size
60-
while batch_size * i < input.shape[0]:
54+
while batch_size * index < input.shape[0]:
6155
logger.info(
6256
"uploading text embeddings batch %d/%d of size %d to vector store",
63-
i + 1,
57+
index + 1,
6458
num_total_batches,
6559
batch_size,
6660
)
67-
batch = input.iloc[batch_size * i : batch_size * (i + 1)]
61+
batch = input.iloc[batch_size * index : batch_size * (index + 1)]
6862
texts: list[str] = batch[embed_column].tolist()
6963
ids: list[str] = batch[id_column].tolist()
7064
result = await run_embed_text(
@@ -93,8 +87,7 @@ async def embed_text(
9387
)
9488
documents.append(document)
9589

96-
vector_store.load_documents(documents, True)
97-
starting_index += len(documents)
98-
i += 1
90+
vector_store.load_documents(documents)
91+
index += 1
9992

10093
return all_results

graphrag/vector_stores/azure_ai_search.py

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -74,58 +74,57 @@ def connect(self, **kwargs: Any) -> Any:
7474
not_supported_error = "Azure AI Search expects `url`."
7575
raise ValueError(not_supported_error)
7676

77-
def load_documents(
78-
self, documents: list[VectorStoreDocument], overwrite: bool = True
79-
) -> None:
77+
def create_index(self) -> None:
8078
"""Load documents into an Azure AI Search index."""
81-
if overwrite:
82-
if (
83-
self.index_name is not None
84-
and self.index_name in self.index_client.list_index_names()
85-
):
86-
self.index_client.delete_index(self.index_name)
87-
88-
# Configure vector search profile
89-
vector_search = VectorSearch(
90-
algorithms=[
91-
HnswAlgorithmConfiguration(
92-
name="HnswAlg",
93-
parameters=HnswParameters(
94-
metric=VectorSearchAlgorithmMetric.COSINE
95-
),
96-
)
97-
],
98-
profiles=[
99-
VectorSearchProfile(
100-
name=self.vector_search_profile_name,
101-
algorithm_configuration_name="HnswAlg",
102-
)
103-
],
104-
)
105-
# Configure the index
106-
index = SearchIndex(
107-
name=self.index_name if self.index_name else "",
108-
fields=[
109-
SimpleField(
110-
name=self.id_field,
111-
type=SearchFieldDataType.String,
112-
key=True,
113-
),
114-
SearchField(
115-
name=self.vector_field,
116-
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
117-
searchable=True,
118-
hidden=False, # DRIFT needs to return the vector for client-side similarity
119-
vector_search_dimensions=self.vector_size,
120-
vector_search_profile_name=self.vector_search_profile_name,
79+
if (
80+
self.index_name is not None
81+
and self.index_name in self.index_client.list_index_names()
82+
):
83+
self.index_client.delete_index(self.index_name)
84+
85+
# Configure vector search profile
86+
vector_search = VectorSearch(
87+
algorithms=[
88+
HnswAlgorithmConfiguration(
89+
name="HnswAlg",
90+
parameters=HnswParameters(
91+
metric=VectorSearchAlgorithmMetric.COSINE
12192
),
122-
],
123-
vector_search=vector_search,
124-
)
125-
self.index_client.create_or_update_index(
126-
index,
127-
)
93+
)
94+
],
95+
profiles=[
96+
VectorSearchProfile(
97+
name=self.vector_search_profile_name,
98+
algorithm_configuration_name="HnswAlg",
99+
)
100+
],
101+
)
102+
# Configure the index
103+
index = SearchIndex(
104+
name=self.index_name if self.index_name else "",
105+
fields=[
106+
SimpleField(
107+
name=self.id_field,
108+
type=SearchFieldDataType.String,
109+
key=True,
110+
),
111+
SearchField(
112+
name=self.vector_field,
113+
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
114+
searchable=True,
115+
hidden=False, # DRIFT needs to return the vector for client-side similarity
116+
vector_search_dimensions=self.vector_size,
117+
vector_search_profile_name=self.vector_search_profile_name,
118+
),
119+
],
120+
vector_search=vector_search,
121+
)
122+
self.index_client.create_or_update_index(
123+
index,
124+
)
128125

126+
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
127+
"""Load documents into an Azure AI Search index."""
129128
batch = [
130129
{
131130
self.id_field: doc.id,

graphrag/vector_stores/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def connect(self, **kwargs: Any) -> None:
5858
"""Connect to vector storage."""
5959

6060
@abstractmethod
61-
def load_documents(
62-
self, documents: list[VectorStoreDocument], overwrite: bool = True
63-
) -> None:
61+
def create_index(self) -> None:
62+
"""Create index."""
63+
64+
@abstractmethod
65+
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
6466
"""Load documents into the vector-store."""
6567

6668
@abstractmethod

graphrag/vector_stores/cosmosdb.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,19 +149,18 @@ def _container_exists(self) -> bool:
149149
]
150150
return self._container_name in existing_container_names
151151

152-
def load_documents(
153-
self, documents: list[VectorStoreDocument], overwrite: bool = True
154-
) -> None:
152+
def create_index(self) -> None:
155153
"""Load documents into CosmosDB."""
156154
# Create a CosmosDB container on overwrite
157-
if overwrite:
158-
self._delete_container()
159-
self._create_container()
155+
self._delete_container()
156+
self._create_container()
160157

161158
if self._container_client is None:
162159
msg = "Container client is not initialized."
163160
raise ValueError(msg)
164161

162+
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
163+
"""Load documents into CosmosDB."""
165164
# Upload documents to CosmosDB
166165
for doc in documents:
167166
if doc.vector is not None:

graphrag/vector_stores/lancedb.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,33 @@ def connect(self, **kwargs: Any) -> Any:
3535
if self.index_name and self.index_name in self.db_connection.table_names():
3636
self.document_collection = self.db_connection.open_table(self.index_name)
3737

38-
def load_documents(
39-
self, documents: list[VectorStoreDocument], overwrite: bool = True
40-
) -> None:
38+
def create_index(self) -> None:
39+
"""Create index."""
40+
dummy_vector = np.zeros(self.vector_size, dtype=np.float32)
41+
flat_array = pa.array(dummy_vector, type=pa.float32())
42+
vector_column = pa.FixedSizeListArray.from_arrays(flat_array, self.vector_size)
43+
44+
data = pa.table({
45+
self.id_field: pa.array(["__DUMMY__"], type=pa.string()),
46+
self.vector_field: vector_column,
47+
})
48+
49+
self.document_collection = self.db_connection.create_table(
50+
self.index_name if self.index_name else "",
51+
data=data,
52+
mode="overwrite",
53+
schema=data.schema,
54+
)
55+
56+
# Step 5: Create index now that schema exists
57+
self.document_collection.create_index(
58+
vector_column_name=self.vector_field, index_type="IVF_FLAT"
59+
)
60+
61+
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
4162
"""Load documents into vector storage."""
63+
self.document_collection.delete(f"{self.id_field} = '__DUMMY__'")
64+
4265
# Step 1: Prepare data columns manually
4366
ids = []
4467
vectors = []
@@ -68,31 +91,13 @@ def load_documents(
6891
self.vector_field: vector_column,
6992
})
7093

71-
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
72-
# The pyarrow format of the 'vector' field may change if the order of operations is changed
73-
# and will break vector search.
74-
if overwrite:
7594
if data:
7695
self.document_collection = self.db_connection.create_table(
7796
self.index_name if self.index_name else "",
7897
data=data,
7998
mode="overwrite",
8099
schema=data.schema,
81100
)
82-
else:
83-
self.document_collection = self.db_connection.create_table(
84-
self.index_name if self.index_name else "", mode="overwrite"
85-
)
86-
self.document_collection.create_index(
87-
vector_column_name=self.vector_field, index_type="IVF_FLAT"
88-
)
89-
else:
90-
# add data to existing table
91-
self.document_collection = self.db_connection.open_table(
92-
self.index_name if self.index_name else ""
93-
)
94-
if data:
95-
self.document_collection.add(data)
96101

97102
def similarity_search_by_vector(
98103
self, query_embedding: list[float] | np.ndarray, k: int = 10

tests/integration/vector_stores/test_azure_ai_search.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ async def test_vector_store_operations(
120120
"vector": [0.1, 0.2, 0.3, 0.4, 0.5],
121121
}
122122

123+
vector_store.create_index()
123124
vector_store.load_documents(sample_documents)
124125
assert mock_index_client.create_or_update_index.called
125126
assert mock_search_client.upload_documents.called
@@ -188,6 +189,7 @@ async def test_vector_store_customization(
188189
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
189190
}
190191

192+
vector_store_custom.create_index()
191193
vector_store_custom.load_documents(sample_documents)
192194
assert mock_index_client.create_or_update_index.called
193195
assert mock_search_client.upload_documents.called

tests/integration/vector_stores/test_cosmosdb.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def test_vector_store_operations():
4444
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
4545
),
4646
]
47+
48+
vector_store.create_index()
4749
vector_store.load_documents(docs)
4850

4951
doc = vector_store.search_by_id("doc1")
@@ -84,6 +86,7 @@ def test_clear():
8486
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
8587
)
8688

89+
vector_store.create_index()
8790
vector_store.load_documents([doc])
8891
result = vector_store.search_by_id("test")
8992
assert result.id == "test"
@@ -122,6 +125,8 @@ def test_vector_store_customization():
122125
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
123126
),
124127
]
128+
129+
vector_store.create_index()
125130
vector_store.load_documents(docs)
126131

127132
doc = vector_store.search_by_id("doc1")

tests/integration/vector_stores/test_factory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ def __init__(self, **kwargs):
131131
def connect(self, **kwargs):
132132
pass
133133

134-
def load_documents(self, documents, overwrite=True):
134+
def create_index(self, **kwargs):
135+
pass
136+
137+
def load_documents(self, documents):
135138
pass
136139

137140
def similarity_search_by_vector(self, query_embedding, k=10, **kwargs):

tests/integration/vector_stores/test_lancedb.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def test_vector_store_operations(self, sample_documents):
6464
)
6565
)
6666
vector_store.connect(db_uri=temp_dir)
67+
vector_store.create_index()
6768
vector_store.load_documents(sample_documents[:2])
6869

6970
if vector_store.index_name:
@@ -83,7 +84,8 @@ def test_vector_store_operations(self, sample_documents):
8384
assert isinstance(results[0].score, float)
8485

8586
# Test append mode
86-
vector_store.load_documents([sample_documents[2]], overwrite=False)
87+
vector_store.create_index()
88+
vector_store.load_documents([sample_documents[2]])
8789
result = vector_store.search_by_id("3")
8890
assert result.id == "3"
8991

@@ -121,6 +123,7 @@ def test_empty_collection(self):
121123
id="tmp",
122124
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
123125
)
126+
vector_store.create_index()
124127
vector_store.load_documents([sample_doc])
125128
vector_store.db_connection.open_table(
126129
vector_store.index_name if vector_store.index_name else ""
@@ -137,7 +140,8 @@ def test_empty_collection(self):
137140
id="1",
138141
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
139142
)
140-
vector_store.load_documents([doc], overwrite=False)
143+
vector_store.create_index()
144+
vector_store.load_documents([doc])
141145

142146
result = vector_store.search_by_id("1")
143147
assert result.id == "1"
@@ -157,7 +161,7 @@ def test_filter_search(self, sample_documents_categories):
157161
)
158162

159163
vector_store.connect(db_uri=temp_dir)
160-
164+
vector_store.create_index()
161165
vector_store.load_documents(sample_documents_categories)
162166

163167
# Filter to include only documents about animals
@@ -186,6 +190,7 @@ def test_vector_store_customization(self, sample_documents):
186190
),
187191
)
188192
vector_store.connect(db_uri=temp_dir)
193+
vector_store.create_index()
189194
vector_store.load_documents(sample_documents[:2])
190195

191196
if vector_store.index_name:
@@ -205,7 +210,8 @@ def test_vector_store_customization(self, sample_documents):
205210
assert isinstance(results[0].score, float)
206211

207212
# Test append mode
208-
vector_store.load_documents([sample_documents[2]], overwrite=False)
213+
vector_store.create_index()
214+
vector_store.load_documents([sample_documents[2]])
209215
result = vector_store.search_by_id("3")
210216
assert result.id == "3"
211217

0 commit comments

Comments
 (0)