Skip to content

Commit 1fe6aeb

Browse files
committed
feat(weaviate): add tenant support in write_documents and async batch write
1 parent 7b85813 commit 1fe6aeb

File tree

2 files changed

+35
-44
lines changed

2 files changed

+35
-44
lines changed

integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,11 @@ def _clean_connection_settings(self) -> None:
177177
_class_name = self._collection_settings.get("class", "Default")
178178
_class_name = _class_name[0].upper() + _class_name[1:]
179179
self._collection_settings["class"] = _class_name
180+
180181
# Set the properties if they're not set
181182
self._collection_settings["properties"] = self._collection_settings.get(
182183
"properties", DOCUMENT_COLLECTION_PROPERTIES
183-
)
184+
)
184185

185186
@property
186187
def client(self) -> weaviate.WeaviateClient:
@@ -937,11 +938,13 @@ def _handle_failed_objects(failed_objects: list[ErrorObject]) -> NoReturn:
937938
def _batch_write(self, documents: list[Document], tenant: str | None = None) -> int:
938939
"""
939940
Writes document to Weaviate in batches.
940-
941-
Documents with the same id will be overwritten.
942-
Raises in case of errors.
943941
"""
944942

943+
# Handle tenant at collection level (NOT via kwargs)
944+
collection = self.collection
945+
if tenant is not None:
946+
collection = collection.with_tenant(tenant)
947+
945948
with self.client.batch.dynamic() as batch:
946949
for doc in documents:
947950
if not isinstance(doc, Document):
@@ -950,27 +953,28 @@ def _batch_write(self, documents: list[Document], tenant: str | None = None) ->
950953

951954
batch.add_object(
952955
properties=WeaviateDocumentStore._to_data_object(doc),
953-
collection=self.collection.name,
956+
collection=collection.name,
954957
uuid=generate_uuid5(doc.id),
955958
vector=doc.embedding,
956-
tenant=tenant
957959
)
960+
958961
if failed_objects := self.client.batch.failed_objects:
959962
self._handle_failed_objects(failed_objects)
960963

961-
# If the document already exists we get no status message back from Weaviate.
962-
# So we assume that all Documents were written.
963964
return len(documents)
964965

965966
async def _batch_write_async(self, documents: list[Document], tenant: str | None = None) -> int:
966967
"""
967968
Asynchronously writes document to Weaviate in batches.
968-
969-
Documents with the same id will be overwritten.
970-
Raises in case of errors.
971969
"""
970+
972971
client = await self.async_client
973972

973+
# Handle tenant properly
974+
collection = await self.async_collection
975+
if tenant is not None:
976+
collection = collection.with_tenant(tenant)
977+
974978
async with client.batch.stream() as batch:
975979
for doc in documents:
976980
if not isinstance(doc, Document):
@@ -979,17 +983,14 @@ async def _batch_write_async(self, documents: list[Document], tenant: str | None
979983

980984
await batch.add_object(
981985
properties=WeaviateDocumentStore._to_data_object(doc),
982-
collection=(await self.async_collection).name,
986+
collection=collection.name,
983987
uuid=generate_uuid5(doc.id),
984988
vector=doc.embedding,
985-
tenant=tenant
986989
)
987990

988991
if failed_objects := client.batch.failed_objects:
989992
self._handle_failed_objects(failed_objects)
990993

991-
# If the document already exists we get no status message back from Weaviate.
992-
# So we assume that all Documents were written.
993994
return len(documents)
994995

995996
def _write(self, documents: list[Document], policy: DuplicatePolicy, tenant: str | None = None) -> int:
@@ -1260,25 +1261,23 @@ async def delete_all_documents_async(self, *, recreate_index: bool = False, batc
12601261
)
12611262

12621263
def delete_by_filter(self, filters: dict[str, Any]) -> int:
1263-
"""
1264-
Deletes all documents that match the provided filters.
1265-
1266-
:param filters: The filters to apply to select documents for deletion.
1267-
For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering)
1268-
:returns: The number of documents deleted.
1269-
"""
12701264
validate_filters(filters)
12711265

12721266
try:
1267+
collection = self.collection # ✅ FIX
1268+
12731269
weaviate_filter = convert_filters(filters)
1274-
result = self.collection.data.delete_many(where=weaviate_filter)
1270+
result = collection.data.delete_many(where=weaviate_filter)
12751271
deleted_count = result.successful
1272+
12761273
logger.info(
12771274
"Deleted {n_docs} documents from collection '{collection}' using filters.",
12781275
n_docs=deleted_count,
1279-
collection=self.collection.name,
1276+
collection=collection.name,
12801277
)
1278+
12811279
return deleted_count
1280+
12821281
except weaviate.exceptions.WeaviateQueryError as e:
12831282
msg = f"Failed to delete documents by filter in Weaviate. Error: {e.message}"
12841283
raise DocumentStoreError(msg) from e
@@ -1315,54 +1314,42 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int:
13151314
raise DocumentStoreError(msg) from e
13161315

13171316
def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int:
1318-
"""
1319-
Updates the metadata of all documents that match the provided filters.
1320-
1321-
:param filters: The filters to apply to select documents for updating.
1322-
For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering)
1323-
:param meta: The metadata fields to update. These will be merged with existing metadata.
1324-
:returns: The number of documents updated.
1325-
"""
13261317
validate_filters(filters)
13271318

13281319
if not isinstance(meta, dict):
13291320
msg = "Meta must be a dictionary"
13301321
raise ValueError(msg)
13311322

13321323
try:
1324+
collection = self.collection # ✅ FIX
1325+
13331326
matching_objects = self._query_with_filters(filters)
13341327
if not matching_objects:
13351328
return 0
13361329

1337-
# Update each object with the new metadata
1338-
# Since metadata is stored flattened in Weaviate properties, we update properties directly
13391330
updated_count = 0
13401331
failed_updates = []
13411332

13421333
for obj in matching_objects:
13431334
try:
1344-
# Get current properties
13451335
current_properties = obj.properties.copy() if obj.properties else {}
13461336

1347-
# Update with new metadata values
1348-
# Note: metadata fields are stored directly in properties (flattened)
13491337
for key, value in meta.items():
13501338
current_properties[key] = value
13511339

1352-
# Update the object, preserving the vector
1353-
# Get the vector from the object to preserve it during replace
13541340
vector: VECTORS | None = None
13551341
if isinstance(obj.vector, (list, dict)):
13561342
vector = obj.vector
13571343

1358-
self.collection.data.replace(
1344+
collection.data.replace( # ✅ FIX
13591345
uuid=obj.uuid,
13601346
properties=current_properties,
13611347
vector=vector,
13621348
)
1349+
13631350
updated_count += 1
1351+
13641352
except Exception as e:
1365-
# Collect failed updates but continue with others
13661353
obj_properties = obj.properties or {}
13671354
id_ = obj_properties.get("_original_id", obj.uuid)
13681355
failed_updates.append((id_, str(e)))
@@ -1376,9 +1363,11 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int
13761363
logger.info(
13771364
"Updated {n_docs} documents in collection '{collection}' using filters.",
13781365
n_docs=updated_count,
1379-
collection=self.collection.name,
1366+
collection=collection.name,
13801367
)
1368+
13811369
return updated_count
1370+
13821371
except weaviate.exceptions.WeaviateQueryError as e:
13831372
msg = f"Failed to update documents by filter in Weaviate. Error: {e.message}"
13841373
raise DocumentStoreError(msg) from e
@@ -1618,4 +1607,4 @@ async def _hybrid_retrieval_async(
16181607
return_metadata=["score"],
16191608
)
16201609

1621-
return [WeaviateDocumentStore._to_document(doc) for doc in result.objects]
1610+
return [WeaviateDocumentStore._to_document(doc) for doc in result.objects]

integrations/weaviate/tests/test_document_store.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Generator
99
from unittest.mock import MagicMock, patch
1010

11+
import platform
1112
import pytest
1213
from dateutil import parser
1314
from haystack.dataclasses.byte_stream import ByteStream
@@ -833,6 +834,7 @@ def test_connect_to_local(self):
833834
)
834835
assert document_store.client
835836

837+
@pytest.mark.skipif(platform.system() == "Windows", reason="EmbeddedDB not supported on Windows")
836838
def test_connect_to_embedded(self):
837839
document_store = WeaviateDocumentStore(embedded_options=EmbeddedOptions())
838840
assert document_store.client
@@ -1142,4 +1144,4 @@ def test_count_unique_metadata_by_filter_all_documents(document_store):
11421144
)
11431145
assert counts["category"] == 3
11441146
assert counts["status"] == 2
1145-
assert counts["priority"] == 3
1147+
assert counts["priority"] == 3

0 commit comments

Comments
 (0)