Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def _clean_connection_settings(self) -> None:
_class_name = self._collection_settings.get("class", "Default")
_class_name = _class_name[0].upper() + _class_name[1:]
self._collection_settings["class"] = _class_name

# Set the properties if they're not set
self._collection_settings["properties"] = self._collection_settings.get(
"properties", DOCUMENT_COLLECTION_PROPERTIES
Expand Down Expand Up @@ -934,14 +935,16 @@ def _handle_failed_objects(failed_objects: list[ErrorObject]) -> NoReturn:
)
raise DocumentStoreError(msg)

def _batch_write(self, documents: list[Document]) -> int:
def _batch_write(self, documents: list[Document], tenant: str | None = None) -> int:
"""
Writes document to Weaviate in batches.

Documents with the same id will be overwritten.
Raises in case of errors.
"""

# Handle tenant at collection level (NOT via kwargs)
collection = self.collection
if tenant is not None:
collection = collection.with_tenant(tenant)

with self.client.batch.dynamic() as batch:
for doc in documents:
if not isinstance(doc, Document):
Expand All @@ -950,26 +953,28 @@ def _batch_write(self, documents: list[Document]) -> int:

batch.add_object(
properties=WeaviateDocumentStore._to_data_object(doc),
collection=self.collection.name,
collection=collection.name,
uuid=generate_uuid5(doc.id),
vector=doc.embedding,
)

if failed_objects := self.client.batch.failed_objects:
self._handle_failed_objects(failed_objects)

# If the document already exists we get no status message back from Weaviate.
# So we assume that all Documents were written.
return len(documents)

async def _batch_write_async(self, documents: list[Document]) -> int:
async def _batch_write_async(self, documents: list[Document], tenant: str | None = None) -> int:
"""
Asynchronously writes document to Weaviate in batches.

Documents with the same id will be overwritten.
Raises in case of errors.
"""

client = await self.async_client

# Handle tenant properly
collection = await self.async_collection
if tenant is not None:
collection = collection.with_tenant(tenant)

async with client.batch.stream() as batch:
for doc in documents:
if not isinstance(doc, Document):
Expand All @@ -978,39 +983,40 @@ async def _batch_write_async(self, documents: list[Document]) -> int:

await batch.add_object(
properties=WeaviateDocumentStore._to_data_object(doc),
collection=(await self.async_collection).name,
collection=collection.name,
uuid=generate_uuid5(doc.id),
vector=doc.embedding,
)

if failed_objects := client.batch.failed_objects:
self._handle_failed_objects(failed_objects)

# If the document already exists we get no status message back from Weaviate.
# So we assume that all Documents were written.
return len(documents)

def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int:
def _write(self, documents: list[Document], policy: DuplicatePolicy, tenant: str | None = None) -> int:
"""
Writes documents to Weaviate using the specified policy.

This doesn't use the batch API, so it's slower than _batch_write.
If policy is set to SKIP it will skip any document that already exists.
If policy is set to FAIL it will raise an exception if any of the documents already exists.
"""
collection = self.collection
if tenant:
collection = collection.with_tenant(tenant)
written = 0
duplicate_errors_ids = []
for doc in documents:
if not isinstance(doc, Document):
msg = f"Expected a Document, got '{type(doc)}' instead."
raise ValueError(msg)

if policy == DuplicatePolicy.SKIP and self.collection.data.exists(uuid=generate_uuid5(doc.id)):
if policy == DuplicatePolicy.SKIP and collection.data.exists(uuid=generate_uuid5(doc.id)):
# This Document already exists, we skip it
continue

try:
self.collection.data.insert(
collection.data.insert(
uuid=generate_uuid5(doc.id),
properties=WeaviateDocumentStore._to_data_object(doc),
vector=doc.embedding,
Expand All @@ -1025,7 +1031,12 @@ def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int:
raise DuplicateDocumentError(msg)
return written

async def _write_async(self, documents: list[Document], policy: DuplicatePolicy) -> int:
async def _write_async(
self,
documents: list[Document],
policy: DuplicatePolicy,
tenant: str | None = None,
) -> int:
"""
Asynchronously writes documents to Weaviate using the specified policy.

Expand All @@ -1034,16 +1045,15 @@ async def _write_async(self, documents: list[Document], policy: DuplicatePolicy)
If policy is set to FAIL it will raise an exception if any of the documents already exists.
"""
collection = await self.async_collection

if tenant:
collection = collection.with_tenant(tenant)
duplicate_errors_ids = []
for doc in documents:
if not isinstance(doc, Document):
msg = f"Expected a Document, got '{type(doc)}' instead."
raise ValueError(msg)

if policy == DuplicatePolicy.SKIP and await (await self.async_collection).data.exists(
uuid=generate_uuid5(doc.id)
):
if policy == DuplicatePolicy.SKIP and await collection.data.exists(uuid=generate_uuid5(doc.id)):
# This Document already exists, continue
continue

Expand All @@ -1063,7 +1073,12 @@ async def _write_async(self, documents: list[Document], policy: DuplicatePolicy)
raise DuplicateDocumentError(msg)
return len(documents)

def write_documents(self, documents: list[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
def write_documents(
self,
documents: list[Document],
policy: DuplicatePolicy = DuplicatePolicy.NONE,
tenant: str | None = None,
) -> int:
"""
Writes documents to Weaviate using the specified policy.

Expand All @@ -1089,12 +1104,15 @@ def write_documents(self, documents: list[Document], policy: DuplicatePolicy = D
The number of documents written.
"""
if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]:
return self._batch_write(documents)
return self._batch_write(documents, tenant)

return self._write(documents, policy)
return self._write(documents, policy, tenant)

async def write_documents_async(
self, documents: list[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE
self,
documents: list[Document],
policy: DuplicatePolicy = DuplicatePolicy.NONE,
tenant: str | None = None,
) -> int:
"""
Asynchronously writes documents to Weaviate using the specified policy.
Expand All @@ -1121,9 +1139,9 @@ async def write_documents_async(
The number of documents written.
"""
if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]:
return await self._batch_write_async(documents)
return await self._batch_write_async(documents, tenant)

return await self._write_async(documents, policy)
return await self._write_async(documents, policy, tenant)

def delete_documents(self, document_ids: list[str]) -> None:
"""
Expand Down Expand Up @@ -1249,21 +1267,26 @@ def delete_by_filter(self, filters: dict[str, Any]) -> int:
Deletes all documents that match the provided filters.

:param filters: The filters to apply to select documents for deletion.
For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering)
For filter syntax, see Haystack metadata filtering docs.
:returns: The number of documents deleted.
"""
validate_filters(filters)

try:
collection = self.collection

weaviate_filter = convert_filters(filters)
result = self.collection.data.delete_many(where=weaviate_filter)
result = collection.data.delete_many(where=weaviate_filter)
deleted_count = result.successful

logger.info(
"Deleted {n_docs} documents from collection '{collection}' using filters.",
n_docs=deleted_count,
collection=self.collection.name,
collection=collection.name,
)

return deleted_count

except weaviate.exceptions.WeaviateQueryError as e:
msg = f"Failed to delete documents by filter in Weaviate. Error: {e.message}"
raise DocumentStoreError(msg) from e
Expand All @@ -1275,23 +1298,26 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int:
"""
Asynchronously deletes all documents that match the provided filters.

:param filters: The filters to apply to select documents for deletion.
For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering)
:returns: The number of documents deleted.
:param filters: Filters to select documents for deletion.
:returns: Number of deleted documents.
"""
validate_filters(filters)

try:
collection = await self.async_collection

weaviate_filter = convert_filters(filters)
result = await collection.data.delete_many(where=weaviate_filter)
deleted_count = result.successful

logger.info(
"Deleted {n_docs} documents from collection '{collection}' using filters.",
n_docs=deleted_count,
collection=collection.name,
)

return deleted_count

except weaviate.exceptions.WeaviateQueryError as e:
msg = f"Failed to delete documents by filter in Weaviate. Error: {e.message}"
raise DocumentStoreError(msg) from e
Expand All @@ -1301,12 +1327,11 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int:

def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int:
"""
Updates the metadata of all documents that match the provided filters.
Updates metadata of all documents that match the provided filters.

:param filters: The filters to apply to select documents for updating.
For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering)
:param meta: The metadata fields to update. These will be merged with existing metadata.
:returns: The number of documents updated.
:param filters: Filters to select documents for updating.
:param meta: Metadata fields to update.
:returns: Number of updated documents.
"""
validate_filters(filters)

Expand All @@ -1315,39 +1340,35 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int
raise ValueError(msg)

try:
collection = self.collection # ✅ FIX

matching_objects = self._query_with_filters(filters)
if not matching_objects:
return 0

# Update each object with the new metadata
# Since metadata is stored flattened in Weaviate properties, we update properties directly
updated_count = 0
failed_updates = []

for obj in matching_objects:
try:
# Get current properties
current_properties = obj.properties.copy() if obj.properties else {}

# Update with new metadata values
# Note: metadata fields are stored directly in properties (flattened)
for key, value in meta.items():
current_properties[key] = value

# Update the object, preserving the vector
# Get the vector from the object to preserve it during replace
vector: VECTORS | None = None
if isinstance(obj.vector, (list, dict)):
if isinstance(obj.vector, list | dict):
vector = obj.vector

self.collection.data.replace(
collection.data.replace( # ✅ FIX
uuid=obj.uuid,
properties=current_properties,
vector=vector,
)

updated_count += 1

except Exception as e:
# Collect failed updates but continue with others
obj_properties = obj.properties or {}
id_ = obj_properties.get("_original_id", obj.uuid)
failed_updates.append((id_, str(e)))
Expand All @@ -1361,9 +1382,11 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int
logger.info(
"Updated {n_docs} documents in collection '{collection}' using filters.",
n_docs=updated_count,
collection=self.collection.name,
collection=collection.name,
)

return updated_count

except weaviate.exceptions.WeaviateQueryError as e:
msg = f"Failed to update documents by filter in Weaviate. Error: {e.message}"
raise DocumentStoreError(msg) from e
Expand Down Expand Up @@ -1431,7 +1454,7 @@ async def update_by_filter_async(self, filters: dict[str, Any], meta: dict[str,
# Update the object, preserving the vector
# Get the vector from the object to preserve it during replace
vector: VECTORS | None = None
if isinstance(obj.vector, (list, dict)):
if isinstance(obj.vector, list | dict):
vector = obj.vector

await collection.data.replace(
Expand Down
11 changes: 11 additions & 0 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import base64
import logging
import os
import platform
from collections.abc import Generator
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -377,6 +378,15 @@ def test_write_documents(self, document_store):
assert document_store.write_documents([doc]) == 1
assert document_store.count_documents() == 1

def test_write_documents_with_tenant(self, document_store):
doc = Document(content="tenant test doc")

# Write with tenant
written = document_store.write_documents([doc], tenant="tenant1")

assert written == 1
assert document_store.count_documents() == 1

def test_write_documents_with_blob_data(self, document_store, test_files_path):
image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg")
doc = Document(content="test doc", blob=image)
Expand Down Expand Up @@ -824,6 +834,7 @@ def test_connect_to_local(self):
)
assert document_store.client

@pytest.mark.skipif(platform.system() == "Windows", reason="EmbeddedDB not supported on Windows")
def test_connect_to_embedded(self):
document_store = WeaviateDocumentStore(embedded_options=EmbeddedOptions())
assert document_store.client
Expand Down
12 changes: 12 additions & 0 deletions integrations/weaviate/tests/test_document_store_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ async def test_write_documents_async(self, document_store: WeaviateDocumentStore
assert await document_store.write_documents_async([doc]) == 1
assert await document_store.count_documents_async() == 1

@pytest.mark.asyncio
async def test_write_documents_with_tenant_async(self, document_store):
doc = Document(content="tenant test doc")

written = await document_store.write_documents_async([doc], tenant="tenant1")

assert written == 1

docs = await document_store.filter_documents_async()
assert len(docs) == 1
assert docs[0].content == "tenant test doc"

@pytest.mark.asyncio
async def test_write_documents_with_blob_data_async(
self, document_store: WeaviateDocumentStore, test_files_path: Path
Expand Down