Skip to content

Commit 458dd74

Browse files
committed
Add bulk_embed_and_insert_texts
1 parent 8d776d8 commit 458dd74

File tree

5 files changed

+380
-0
lines changed

5 files changed

+380
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
.DS_STORE
33
__pycache__/
4+
.env

pymongo_vectorsearch_utils/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
drop_vector_search_index,
66
update_vector_search_index,
77
)
8+
from .operation import bulk_embed_and_insert_texts
89
from .pipeline import (
910
combine_pipelines,
1011
final_hybrid_stage,
1112
reciprocal_rank_stage,
1213
text_search_stage,
1314
vector_search_stage,
1415
)
16+
from .util import oid_to_str, str_to_oid
1517

1618
__all__ = [
1719
"__version__",
@@ -24,4 +26,7 @@
2426
"combine_pipelines",
2527
"reciprocal_rank_stage",
2628
"final_hybrid_stage",
29+
"bulk_embed_and_insert_texts",
30+
"str_to_oid",
31+
"oid_to_str",
2732
]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""CRUD utilities for embedding texts and writing them to MongoDB."""
2+
3+
from collections.abc import Callable, Generator, Iterable
4+
from typing import Any
5+
6+
from bson import ObjectId
7+
from pymongo import ReplaceOne
8+
from pymongo.synchronous.collection import Collection
9+
10+
from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid
11+
12+
13+
def bulk_embed_and_insert_texts(
14+
texts: list[str] | Iterable[str],
15+
metadatas: list[dict] | Generator[dict, Any, Any],
16+
embedding_func: Callable[[list[str]], list[list[float]]],
17+
collection: Collection[Any],
18+
text_key: str,
19+
embedding_key: str,
20+
ids: list[str] | None = None,
21+
**kwargs: Any,
22+
) -> list[str]:
23+
"""Bulk insert single batch of texts, embeddings, and optionally ids.
24+
25+
Important notes on ids:
26+
- If _id or id is a key in the metadatas dicts, one must
27+
pop them and provide as separate list.
28+
- They must be unique.
29+
- If they are not provided, unique ones are created,
30+
stored as bson.ObjectIds internally, and strings in the database.
31+
These will appear in Document.metadata with key, '_id'.
32+
33+
Args:
34+
texts: Iterable of strings to add to the vectorstore.
35+
metadatas: Optional list of metadatas associated with the texts.
36+
embedding_func: A function that generates embedding vectors from the texts.
37+
collection: The MongoDB collection where documents will be inserted.
38+
text_key: The field name where thet text will be stored in each document.
39+
embedding_key: The field name where the embedding will be stored in each document.
40+
ids: Optional list of unique ids that will be used as index in VectorStore.
41+
See note on ids.
42+
"""
43+
if not texts:
44+
return []
45+
# Compute embedding vectors
46+
embeddings = embedding_func(list(texts))
47+
if not ids:
48+
ids = [str(ObjectId()) for _ in range(len(list(texts)))]
49+
docs = [
50+
{
51+
"_id": str_to_oid(i),
52+
text_key: t,
53+
embedding_key: embedding,
54+
**m,
55+
}
56+
for i, t, m, embedding in zip(ids, texts, metadatas, embeddings, strict=False)
57+
]
58+
operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
59+
# insert the documents in MongoDB Atlas
60+
result = collection.bulk_write(operations)
61+
assert result.upserted_ids is not None
62+
return [oid_to_str(_id) for _id in result.upserted_ids.values()]

pymongo_vectorsearch_utils/util.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
from typing import Any
3+
4+
logger = logging.getLogger(__file__)
5+
6+
7+
def str_to_oid(str_repr: str) -> Any | str:
    """Try to turn a string id into a BSON ObjectId, falling back to the string.

    An ObjectId can only be built from a 24 character hex string. Any other
    string is handed back unchanged, and MongoDB will accept it as-is for the
    main _id index.

    Args:
        str_repr: id as string.

    Returns:
        An ObjectId when ``str_repr`` is a valid ObjectId representation,
        otherwise ``str_repr`` itself.
    """
    from bson import ObjectId
    from bson.errors import InvalidId

    try:
        converted = ObjectId(str_repr)
    except InvalidId:
        # Not an error: arbitrary string ids are allowed, so only note it.
        logger.debug(
            "ObjectIds must be 12-character byte or 24-character hex strings. "
            "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
        )
        return str_repr
    return converted
31+
32+
33+
def oid_to_str(oid: Any) -> str:
    """Return the plain string form of an id read back from MongoDB.

    A thin wrapper around ``str()`` that marks the places where ids leave
    MongoDB and are handed to callers as strings.

    Args:
        oid: The id to convert, typically a bson.ObjectId.

    Returns:
        ``str(oid)``; for an ObjectId this is its 24 character hex string.
    """
    as_text = str(oid)
    return as_text

0 commit comments

Comments
 (0)