Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
.DS_STORE
__pycache__/
.env
5 changes: 5 additions & 0 deletions pymongo_vectorsearch_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
drop_vector_search_index,
update_vector_search_index,
)
from .operation import bulk_embed_and_insert_texts
from .pipeline import (
combine_pipelines,
final_hybrid_stage,
reciprocal_rank_stage,
text_search_stage,
vector_search_stage,
)
from .util import oid_to_str, str_to_oid

__all__ = [
"__version__",
Expand All @@ -24,4 +26,7 @@
"combine_pipelines",
"reciprocal_rank_stage",
"final_hybrid_stage",
"bulk_embed_and_insert_texts",
"str_to_oid",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to export these two

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we expect to use bulk_embed_and_insert_texts? Or do you mean the two str_to_oid and oid_to_str?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

str_to_oid and oid_to_str

"oid_to_str",
]
62 changes: 62 additions & 0 deletions pymongo_vectorsearch_utils/operation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""CRUD utilities and helpers."""

from collections.abc import Callable, Generator, Iterable
from typing import Any

from bson import ObjectId
from pymongo import ReplaceOne
from pymongo.synchronous.collection import Collection

from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid


def bulk_embed_and_insert_texts(
texts: list[str] | Iterable[str],
metadatas: list[dict] | Generator[dict, Any, Any],
embedding_func: Callable[[list[str]], list[list[float]]],
collection: Collection[Any],
text_key: str,
embedding_key: str,
ids: list[str] | None = None,
**kwargs: Any,
) -> list[str]:
"""Bulk insert single batch of texts, embeddings, and optionally ids.

Important notes on ids:
- If _id or id is a key in the metadatas dicts, one must
pop them and provide as separate list.
- They must be unique.
- If they are not provided, unique ones are created,
stored as bson.ObjectIds internally, and strings in the database.
These will appear in Document.metadata with key, '_id'.

Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
embedding_func: A function that generates embedding vectors from the texts.
collection: The MongoDB collection where documents will be inserted.
text_key: The field name where thet text will be stored in each document.
embedding_key: The field name where the embedding will be stored in each document.
ids: Optional list of unique ids that will be used as index in VectorStore.
See note on ids.
"""
if not texts:
return []
# Compute embedding vectors
embeddings = embedding_func(list(texts))
if not ids:
ids = [str(ObjectId()) for _ in range(len(list(texts)))]
docs = [
{
"_id": str_to_oid(i),
text_key: t,
embedding_key: embedding,
**m,
}
for i, t, m, embedding in zip(ids, texts, metadatas, embeddings, strict=False)
]
operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
# insert the documents in MongoDB Atlas
result = collection.bulk_write(operations)
assert result.upserted_ids is not None
return [oid_to_str(_id) for _id in result.upserted_ids.values()]
44 changes: 44 additions & 0 deletions pymongo_vectorsearch_utils/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
from typing import Any

# Name the logger after the module (not its file path) so it takes part in
# the package's logger hierarchy and can be configured by dotted name.
logger = logging.getLogger(__name__)


def str_to_oid(str_repr: str) -> Any | str:
    """Try to turn the string form of an id into a BSON ObjectId.

    An ObjectId can only be built from a 24 character hex string. Any
    other string is handed back untouched, and MongoDB will use that plain
    string as the value in its ``_id`` index.

    Args:
        str_repr: id as string.

    Returns:
        A ``bson.ObjectId`` when ``str_repr`` is a valid ObjectId string,
        otherwise ``str_repr`` itself.
    """
    from bson import ObjectId
    from bson.errors import InvalidId

    try:
        converted = ObjectId(str_repr)
    except InvalidId:
        # Not an ObjectId-compatible value: note it and fall back to the input.
        logger.debug(
            "ObjectIds must be 12-character byte or 24-character hex strings. "
            "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
        )
    else:
        return converted
    return str_repr


def oid_to_str(oid: Any) -> str:
    """Return the plain string form of an id read back from MongoDB.

    This is a thin wrapper around ``str`` whose main purpose is to mark, at
    the call site, the points where ids leave MongoDB.

    Args:
        oid: The id to convert, typically a ``bson.ObjectId``.

    Returns:
        ``str(oid)``; for an ObjectId this is its 24 character hex string.
    """
    as_text = str(oid)
    return as_text
Loading