diff --git a/.gitignore b/.gitignore index b8ca9d4..83847d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea/ .DS_STORE __pycache__/ +.env diff --git a/pymongo_vectorsearch_utils/__init__.py b/pymongo_vectorsearch_utils/__init__.py index d363331..3027f7c 100644 --- a/pymongo_vectorsearch_utils/__init__.py +++ b/pymongo_vectorsearch_utils/__init__.py @@ -5,6 +5,7 @@ drop_vector_search_index, update_vector_search_index, ) +from .operation import bulk_embed_and_insert_texts from .pipeline import ( combine_pipelines, final_hybrid_stage, @@ -24,4 +25,5 @@ "combine_pipelines", "reciprocal_rank_stage", "final_hybrid_stage", + "bulk_embed_and_insert_texts", ] diff --git a/pymongo_vectorsearch_utils/operation.py b/pymongo_vectorsearch_utils/operation.py new file mode 100644 index 0000000..8e8df83 --- /dev/null +++ b/pymongo_vectorsearch_utils/operation.py @@ -0,0 +1,62 @@ +"""CRUD utilities and helpers.""" + +from collections.abc import Callable, Generator, Iterable +from typing import Any + +from bson import ObjectId +from pymongo import ReplaceOne +from pymongo.synchronous.collection import Collection + +from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid + + +def bulk_embed_and_insert_texts( + texts: list[str] | Iterable[str], + metadatas: list[dict] | Generator[dict, Any, Any], + embedding_func: Callable[[list[str]], list[list[float]]], + collection: Collection[Any], + text_key: str, + embedding_key: str, + ids: list[str] | None = None, + **kwargs: Any, +) -> list[str]: + """Bulk insert single batch of texts, embeddings, and optionally ids. + + Important notes on ids: + - If _id or id is a key in the metadatas dicts, one must + pop them and provide as separate list. + - They must be unique. + - If they are not provided, unique ones are created, + stored as bson.ObjectIds internally, and strings in the database. + These will appear in Document.metadata with key, '_id'. + + Args: + texts: Iterable of strings to add to the vectorstore. 
+ metadatas: Optional list of metadatas associated with the texts. + embedding_func: A function that generates embedding vectors from the texts. + collection: The MongoDB collection where documents will be inserted. + text_key: The field name where thet text will be stored in each document. + embedding_key: The field name where the embedding will be stored in each document. + ids: Optional list of unique ids that will be used as index in VectorStore. + See note on ids. + """ + if not texts: + return [] + # Compute embedding vectors + embeddings = embedding_func(list(texts)) + if not ids: + ids = [str(ObjectId()) for _ in range(len(list(texts)))] + docs = [ + { + "_id": str_to_oid(i), + text_key: t, + embedding_key: embedding, + **m, + } + for i, t, m, embedding in zip(ids, texts, metadatas, embeddings, strict=False) + ] + operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs] + # insert the documents in MongoDB Atlas + result = collection.bulk_write(operations) + assert result.upserted_ids is not None + return [oid_to_str(_id) for _id in result.upserted_ids.values()] diff --git a/pymongo_vectorsearch_utils/util.py b/pymongo_vectorsearch_utils/util.py new file mode 100644 index 0000000..e1c32a4 --- /dev/null +++ b/pymongo_vectorsearch_utils/util.py @@ -0,0 +1,44 @@ +import logging +from typing import Any + +logger = logging.getLogger(__file__) + + +def str_to_oid(str_repr: str) -> Any | str: + """Attempt to cast string representation of id to MongoDB's internal BSON ObjectId. + + To be consistent with ObjectId, input must be a 24 character hex string. + If it is not, MongoDB will happily use the string in the main _id index. + Importantly, the str representation that comes out of MongoDB will have this form. + + Args: + str_repr: id as string. 
class TestBulkEmbedAndInsertTexts:
    """Integration tests for ``bulk_embed_and_insert_texts``.

    Every test takes a ``collection`` fixture (the MongoDB collection that is
    written to) and, except where a ``Mock`` is used instead, a
    ``mock_embedding_func`` fixture supplying the embedding callable.
    """

    def test_empty_texts_returns_empty_list(self, collection: Collection, mock_embedding_func):
        """An empty batch yields an empty list of ids."""
        result = bulk_embed_and_insert_texts(
            texts=[],
            metadatas=[],
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )
        assert result == []

    def test_basic_insertion_with_generated_ids(self, collection: Collection, mock_embedding_func):
        """Without explicit ids, documents get ObjectId ``_id`` values and string ids are returned."""
        texts = ["text one", "text two"]
        metadatas = [{"category": "test_1"}, {"category": "test_2"}]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="content",
            embedding_key="vector",
        )

        assert len(result) == 2
        assert all(isinstance(id_str, str) for id_str in result)

        # Sort explicitly: the assertions below match documents to inputs by
        # position, and an unsorted find() does not promise any particular order.
        docs = list(collection.find({}).sort("category", 1))
        assert len(docs) == 2

        for i, doc in enumerate(docs):
            assert doc["content"] == texts[i]
            assert doc["vector"] == [float(i), float(i) * 0.5, float(i) * 0.25]
            assert doc["category"] == metadatas[i]["category"]
            assert isinstance(doc["_id"], ObjectId)

    def test_insertion_with_custom_ids(self, collection: Collection, mock_embedding_func):
        """An id that is not an ObjectId hex string is stored and returned as given."""
        texts = ["text one"]
        metadatas = [{"type": "custom"}]
        custom_ids = ["custom_id_123"]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=custom_ids,
        )

        assert result == custom_ids

        doc = collection.find_one({"_id": "custom_id_123"})
        assert doc is not None
        assert doc["text"] == texts[0]
        assert doc["type"] == "custom"

    def test_insertion_with_objectid_string_ids(self, collection: Collection, mock_embedding_func):
        """An id given as an ObjectId string is stored as an ObjectId and returned as a string."""
        texts = ["text one"]
        metadatas = [{"test": True}]
        object_id_str = str(ObjectId())

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=[object_id_str],
        )

        assert result == [object_id_str]

        # Verify document was inserted with ObjectId
        doc = collection.find_one({})
        assert doc is not None
        assert isinstance(doc["_id"], ObjectId)
        assert str(doc["_id"]) == object_id_str

    def test_upsert_behavior(self, collection: Collection, mock_embedding_func):
        """Writing the same id twice replaces the document instead of adding a second one."""
        texts = ["text one"]
        metadatas = [{"version": 1}]
        custom_id = "upsert_id"

        # First insertion
        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=[custom_id],
        )

        # Second write with the same id and different content
        new_metadatas = [{"version": 2}]
        bulk_embed_and_insert_texts(
            texts=["updated text"],
            metadatas=new_metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=[custom_id],
        )

        docs = list(collection.find({}))
        assert len(docs) == 1
        assert docs[0]["text"] == "updated text"
        assert docs[0]["version"] == 2

    def test_with_generator_metadata(self, collection: Collection, mock_embedding_func):
        """Metadata may be supplied by a generator rather than a list."""

        def metadata_generator():
            yield {"index": 0}
            yield {"index": 1}

        result = bulk_embed_and_insert_texts(
            texts=["text one", "text two"],
            metadatas=metadata_generator(),
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        assert len(result) == 2
        docs = list(collection.find({}).sort("index", 1))
        assert len(docs) == 2
        assert docs[0]["text"] == "text one"
        assert docs[1]["text"] == "text two"

    def test_embedding_function_called_correctly(self, collection: Collection):
        """The embedding callable is invoked exactly once, with the whole batch of texts."""
        texts = ["text one", "text two", "text three"]
        metadatas = [{}, {}, {}]

        mock_embedding_func = Mock(return_value=[[1.0], [2.0], [3.0]])

        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        mock_embedding_func.assert_called_once_with(texts)

    def test_large_batch_processing(self, collection: Collection, mock_embedding_func):
        """A batch of 100 documents is written in full."""
        num_docs = 100
        texts = [f"text {i}" for i in range(num_docs)]
        metadatas = [{"doc_num": i} for i in range(num_docs)]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        assert len(result) == num_docs
        assert collection.count_documents({}) == num_docs

    def test_with_additional_kwargs(self, collection: Collection, mock_embedding_func):
        """Unknown keyword arguments are accepted and do not prevent the write."""
        texts = ["text one"]
        metadatas = [{}]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            extra_param="ignored",
        )

        assert len(result) == 1

    def test_mismatched_lengths_handled_gracefully(
        self, collection: Collection, mock_embedding_func
    ):
        """With fewer metadatas than texts, only the paired-up prefix is written."""
        texts = ["text one", "text two"]
        metadatas = [{"meta": 1}]  # Shorter than texts

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        assert len(result) == 1
        docs = list(collection.find({}))
        assert len(docs) == 1
        assert docs[0]["text"] == "text one"

    def test_custom_field_names(self, collection: Collection, mock_embedding_func):
        """``text_key`` and ``embedding_key`` control the field names of the stored document."""
        texts = ["text one"]
        metadatas = [{}]

        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="content",
            embedding_key="vector",
        )

        doc = collection.find_one({})
        assert doc is not None
        assert "content" in doc
        assert "vector" in doc
        assert doc["content"] == texts[0]
        assert doc["vector"] == [0.0, 0.0, 0.0]