Skip to content

Commit 712f703

Browse files
committed
update tests
1 parent f4a6c79 commit 712f703

File tree

7 files changed

+351
-31
lines changed

7 files changed

+351
-31
lines changed

pymongo_voyageai/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from ._version import __version__
22
from .client import PyMongoVoyageAI
33
from .document import Document, DocumentType, ImageDocument, StoredDocument, TextDocument
4-
from .storage import ObjectStorage, S3Storage
4+
from .storage import MemoryStorage, ObjectStorage, S3Storage
55

66
__all__ = [
77
"Document",
@@ -12,5 +12,6 @@
1212
"PyMongoVoyageAI",
1313
"ObjectStorage",
1414
"S3Storage",
15+
"MemoryStorage",
1516
"__version__",
1617
]

pymongo_voyageai/client.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77

88
from bson import ObjectId
9+
from langchain_core.runnables.config import run_in_executor
910
from langchain_mongodb.index import create_vector_search_index
1011
from langchain_mongodb.pipelines import vector_search_stage
1112
from langchain_mongodb.utils import make_serializable
@@ -105,6 +106,17 @@ def image_to_storage(self, document: ImageDocument | Image.Image) -> StoredDocum
105106
document = ImageDocument(image=document)
106107
return self._storage.save_image(document)
107108

109+
async def aimage_to_storage(self, document: ImageDocument | Image.Image) -> StoredDocument:
110+
"""Convert an image to a stored document.
111+
112+
Args:
113+
document: The input document or image object.
114+
115+
Returns:
116+
The stored document object.
117+
"""
118+
return await run_in_executor(None, self.image_to_storage, document)
119+
108120
def storage_to_image(self, document: StoredDocument | str) -> ImageDocument:
109121
"""Convert a stored document to an image document.
110122
@@ -120,6 +132,17 @@ def storage_to_image(self, document: StoredDocument | str) -> ImageDocument:
120132
)
121133
return self._storage.load_image(document=document)
122134

135+
async def astorage_to_image(self, document: StoredDocument | str) -> ImageDocument:
136+
"""Convert a stored document to an image document.
137+
138+
Args:
139+
document: The input document or object name.
140+
141+
Returns:
142+
The image document object.
143+
"""
144+
return await run_in_executor(None, self.storage_to_image, document)
145+
123146
def url_to_images(
124147
self,
125148
url: str,
@@ -145,6 +168,38 @@ def url_to_images(
145168
url, metadata=metadata, start=start, end=end, image_column=image_column, **kwargs
146169
)
147170

171+
async def aurl_to_images(
172+
self,
173+
url: str,
174+
metadata: dict[str, Any] | None = None,
175+
start: int = 0,
176+
end: int | None = None,
177+
image_column: str | None = None,
178+
**kwargs: Any,
179+
) -> list[ImageDocument]:
180+
"""Extract images from a url.
181+
182+
Args:
183+
url: The url to load the images from.
184+
metadata: A set of metadata to associate with the images.
185+
start: The start frame to use for the images.
186+
end: The end frame to use for the images.
187+
image_column: The name of the column used to store the image data, for parquet files.
188+
189+
Returns:
190+
A list of image document objects.
191+
"""
192+
return await run_in_executor(
193+
None,
194+
self.url_to_images,
195+
url,
196+
metadata=metadata,
197+
start=start,
198+
end=end,
199+
image_column=image_column,
200+
**kwargs,
201+
)
202+
148203
def add_documents(
149204
self,
150205
inputs: Sequence[str | Image.Image | Document | Sequence[str | Image.Image | Document]],
@@ -230,6 +285,30 @@ def add_documents(
230285
self._coll.bulk_write(operations)
231286
return output_docs
232287

288+
async def aadd_documents(
289+
self,
290+
inputs: Sequence[str | Image.Image | Document | Sequence[str | Image.Image | Document]],
291+
ids: list[str] | None = None,
292+
batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
293+
**kwargs: Any,
294+
) -> list[dict[str, Any]]:
295+
"""Add multimodal documents to the vectorstore.
296+
297+
Args:
298+
inputs: List of inputs to add to the vectorstore, which are each a list of documents.
299+
ids: Optional list of unique ids that will be used as index in VectorStore.
300+
See note on ids in add_texts.
301+
batch_size: Number of documents to insert at a time.
302+
Tuning this may help with performance and sidestep MongoDB limits.
303+
kwargs: Additional keyword args for future expansion.
304+
305+
Returns:
306+
A list of documents with their associated input documents.
307+
"""
308+
return await run_in_executor(
309+
None, self.add_documents, inputs, ids=ids, batch_size=batch_size, **kwargs
310+
)
311+
233312
def delete_by_ids(
234313
self, ids: list[str | ObjectId], delete_stored_objects: bool = True, **kwargs: Any
235314
) -> bool:
@@ -248,6 +327,23 @@ def delete_by_ids(
248327
{"_id": {"$in": oids}}, delete_stored_objects=delete_stored_objects, **kwargs
249328
)
250329

330+
async def adelete_by_ids(
331+
self, ids: list[str | ObjectId], delete_stored_objects: bool = True, **kwargs: Any
332+
) -> bool:
333+
"""Delete documents by ids.
334+
335+
Args:
336+
ids: List of ids to delete.
337+
delete_stored_objects: Whether to delete the associated stored objects.
338+
**kwargs: Other keyword arguments passed to delete_many().
339+
340+
Returns:
341+
bool: True if deletion is successful, False otherwise.
342+
"""
343+
return await run_in_executor(
344+
None, self.delete_by_ids, ids, delete_stored_objects=delete_stored_objects, **kwargs
345+
)
346+
251347
def delete_many(
252348
self, filter: Mapping[str, Any], delete_stored_objects: bool = True, **kwargs: Any
253349
) -> bool:
@@ -269,11 +365,32 @@ def delete_many(
269365
self._storage.delete_image(inp)
270366
return self._coll.delete_many(filter=filter, **kwargs).acknowledged
271367

368+
async def adelete_many(
369+
self, filter: Mapping[str, Any], delete_stored_objects: bool = True, **kwargs: Any
370+
) -> bool:
371+
"""Delete documents using a filter.
372+
373+
Args:
374+
filter: The filter used to select the documents to delete.
375+
delete_stored_objects: Whether to delete the associated stored objects.
376+
**kwargs: Other keyword arguments passed to the collection's `delete_many` method.
377+
378+
Returns:
379+
bool: True if deletion is successful, False otherwise.
380+
"""
381+
return await run_in_executor(
382+
None, self.delete_many, filter, delete_stored_objects=delete_stored_objects, **kwargs
383+
)
384+
272385
def close(self) -> None:
273386
"""Close the client, cleaning up resources."""
274387
self._coll.database.client.close()
275388
self._storage.close()
276389

390+
async def aclose(self) -> None:
391+
"""Close the client, cleaning up resources."""
392+
return await run_in_executor(None, self.close)
393+
277394
def get_by_ids(
278395
self, ids: Sequence[str | ObjectId], extract_images: bool = True
279396
) -> list[dict[str, Any]]:
@@ -294,6 +411,21 @@ def get_by_ids(
294411
docs.append(doc)
295412
return docs
296413

414+
async def aget_by_ids(
415+
self, ids: Sequence[str | ObjectId], extract_images: bool = True
416+
) -> list[dict[str, Any]]:
417+
"""Get a list of documents by id.
418+
419+
Args:
420+
ids: List of ids to search for.
421+
extract_images: Whether to extract the stored documents into image documents.
422+
423+
Returns:
424+
A list of matching documents, where the `inputs` is a list of stored documents
425+
or image documents.
426+
"""
427+
return await run_in_executor(None, self.get_by_ids, ids, extract_images=extract_images)
428+
297429
def wait_for_indexing(self, timeout: int = TIMEOUT, interval: int = INTERVAL) -> None:
298430
"""Wait for the search index to update to account for newly added embeddings."""
299431
n_docs = self._coll.count_documents({})
@@ -306,6 +438,12 @@ def wait_for_indexing(self, timeout: int = TIMEOUT, interval: int = INTERVAL) ->
306438

307439
raise TimeoutError(f"Failed to embed, insert, and index texts in {timeout}s.")
308440

441+
async def await_for_indexing(self, timeout: int = TIMEOUT, interval: int = INTERVAL) -> None:
442+
"""Wait for the search index to update to account for newly added embeddings."""
443+
return await run_in_executor(
444+
None, self.wait_for_indexing, timeout=timeout, interval=interval
445+
)
446+
309447
def similarity_search(
310448
self,
311449
query: str,
@@ -379,6 +517,53 @@ def similarity_search(
379517
docs.append(res)
380518
return docs
381519

520+
async def asimilarity_search(
521+
self,
522+
query: str,
523+
k: int = 4,
524+
pre_filter: dict[str, Any] | None = None,
525+
post_filter_pipeline: list[dict[str, Any]] | None = None,
526+
oversampling_factor: int = 10,
527+
include_scores: bool = False,
528+
include_embeddings: bool = False,
529+
extract_images: bool = False,
530+
**kwargs: Any,
531+
) -> list[dict[str, Any]]: # noqa: E501
532+
"""Return documents most similar to the given query.
533+
534+
Args:
535+
query: Input text of semantic query.
536+
k: The number of documents to return. Defaults to 4.
537+
pre_filter: List of MQL match expressions comparing an indexed field.
538+
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
539+
to filter/process results after $vectorSearch.
540+
oversampling_factor: Multiple of k used when generating number of candidates
541+
at each step in the HNSW Vector Search.
542+
include_scores: If True, the query score of each result
543+
will be included in metadata.
544+
include_embeddings: If True, the embedding vector of each result
545+
will be included in metadata.
546+
extract_images: If True, the stored documents will be converted to image documents.
547+
kwargs: Additional arguments are specific to the search_type
548+
549+
Returns:
550+
List of documents most similar to the query and their scores, where the `inputs`
551+
is a list of stored documents or image documents.
552+
"""
553+
return await run_in_executor(
554+
None,
555+
self.similarity_search,
556+
query,
557+
k=k,
558+
pre_filter=pre_filter,
559+
post_filter_pipeline=post_filter_pipeline,
560+
oversampling_factor=oversampling_factor,
561+
include_scores=include_scores,
562+
include_embeddings=include_embeddings,
563+
extract_images=extract_images,
564+
**kwargs,
565+
)
566+
382567
def _expand_doc(self, obj: dict[str, Any], extract_images: bool = True) -> dict[str, Any]:
383568
for idx, inp in enumerate(list(obj["inputs"])):
384569
if inp["type"] == DocumentType.storage:

pymongo_voyageai/storage.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,31 @@ def delete_image(self, document: StoredDocument) -> None:
8282

8383
def close(self) -> None:
8484
self.client.close()
85+
86+
87+
class MemoryStorage(ObjectStorage):
88+
"""An in-memory object store"""
89+
90+
def __init__(self) -> None:
91+
self.root_location = "foo"
92+
self.storage: dict[str, ImageDocument] = dict()
93+
94+
def save_image(self, image: ImageDocument) -> StoredDocument:
95+
object_name = str(ObjectId())
96+
self.storage[object_name] = image
97+
return StoredDocument(
98+
root_location=self.root_location,
99+
name=image.name,
100+
object_name=object_name,
101+
source_url=image.source_url,
102+
page_number=image.page_number,
103+
)
104+
105+
def load_image(self, document: StoredDocument) -> ImageDocument:
106+
return self.storage[document.object_name]
107+
108+
def delete_image(self, document: StoredDocument) -> None:
109+
self.storage.pop(document.object_name, None)
110+
111+
def close(self):
112+
pass

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ dev = [
5151
"pyarrow>=19.0.1",
5252
"pre-commit>=4.2.0",
5353
"autodoc-pydantic>=2.2.0",
54+
"pytest-asyncio>=0.26.0",
5455
]
5556

5657

tests/test_client.py renamed to tests/test_client_integration.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,50 +3,25 @@
33

44
import numpy as np
55
import pytest
6-
from bson import ObjectId
76

8-
from pymongo_voyageai import ImageDocument, PyMongoVoyageAI, StoredDocument
9-
from pymongo_voyageai.storage import ImageStorage, S3Storage
7+
from pymongo_voyageai import PyMongoVoyageAI
108

119
if "VOYAGEAI_API_KEY" not in os.environ:
1210
pytest.skip("Requires VoyageAI API Key.", allow_module_level=True)
1311

12+
if "S3_BUCKET_NAME" not in os.environ:
13+
pytest.skip("Requires VoyageAI API Key.", allow_module_level=True)
1414

1515
# mypy: disable_error_code="no-untyped-def"
16-
class MemoryStorage(ImageStorage):
17-
def __init__(self) -> None:
18-
self.root_location = "foo"
19-
self.storage: dict[str, ImageDocument] = dict()
20-
21-
def save_image(self, image: ImageDocument) -> StoredDocument:
22-
object_name = str(ObjectId())
23-
self.storage[object_name] = image
24-
return StoredDocument(
25-
root_location=self.root_location,
26-
name=image.name,
27-
object_name=object_name,
28-
source_url=image.source_url,
29-
page_number=image.page_number,
30-
)
31-
32-
def load_image(self, document: StoredDocument) -> ImageDocument:
33-
return self.storage[document.object_name]
34-
35-
def delete_image(self, document: StoredDocument) -> None:
36-
del self.storage[document.object_name]
3716

3817

3918
@pytest.fixture
4019
def client() -> Generator[PyMongoVoyageAI, None, None]:
4120
conn_str = os.environ.get("MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true")
42-
if "S3_BUCKET" in os.environ:
43-
storage_object = S3Storage(os.environ["S3_BUCKET"])
44-
else:
45-
storage_object = MemoryStorage() # type:ignore[assignment]
4621
client = PyMongoVoyageAI(
4722
voyageai_api_key=os.environ["VOYAGEAI_API_KEY"],
23+
s3_bucket_name=os.environ["S3_BUCKET_NAME"],
4824
mongo_connection_string=conn_str,
49-
storage_object=storage_object,
5025
collection_name="test",
5126
database_name="pymongo_voyageai_test_db",
5227
)
@@ -93,8 +68,25 @@ def test_pdf_pages(client: PyMongoVoyageAI):
9368
images = client.url_to_images(url)
9469
resp = client.add_documents(images)
9570
client.wait_for_indexing()
96-
data = client.similarity_search(query, extract_images=True)
71+
data = client.similarity_search(query, extract_images=False)
9772
# We expect page 5 to be the best match.
9873
assert data[0]["inputs"][0].page_number == 5
9974
assert len(client.get_by_ids([d["_id"] for d in resp])) == len(resp)
10075
client.delete_by_ids([d["_id"] for d in resp])
76+
77+
78+
@pytest.mark.asyncio
79+
async def test_image_set_async(client: PyMongoVoyageAI):
80+
url = "hf://datasets/princeton-nlp/CharXiv/val.parquet"
81+
documents = await client.aurl_to_images(url, image_column="image", end=3)
82+
resp = await client.aadd_documents(documents)
83+
await client.await_for_indexing()
84+
query = "3D loss landscapes for different training strategies"
85+
data = await client.asimilarity_search(query, extract_images=True)
86+
# The best match should be the third input image.
87+
assert data[0]["inputs"][0].image.tobytes() == documents[2].image.tobytes()
88+
ids = await client.aget_by_ids([d["_id"] for d in resp])
89+
assert len(ids) == len(resp)
90+
await client.adelete_by_ids([d["_id"] for d in resp])
91+
await client.adelete_many({})
92+
await client.aclose()

0 commit comments

Comments
 (0)