Skip to content

Commit 17fce12

Browse files
authored
Merge branch 'main' into patch-2
2 parents c247f28 + ffcf7ed commit 17fce12

File tree

7 files changed

+66
-329
lines changed

7 files changed

+66
-329
lines changed

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 7 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22

33
import logging
44
from time import monotonic, sleep
5-
from typing import Any, Callable, Dict, List, Optional, Union
5+
from typing import Any, Callable, Dict, List, Optional
66

77
from pymongo.collection import Collection
8-
from pymongo.operations import SearchIndexModel
8+
from pymongo_search_utils import (
9+
create_fulltext_search_index, # noqa: F401
10+
create_vector_search_index, # noqa: F401
11+
drop_vector_search_index, # noqa: F401
12+
update_vector_search_index, # noqa: F401
13+
)
914

1015
logger = logging.getLogger(__file__)
1116

@@ -37,139 +42,6 @@ def _vector_search_index_definition(
3742
return definition
3843

3944

40-
def create_vector_search_index(
41-
collection: Collection,
42-
index_name: str,
43-
dimensions: int,
44-
path: str,
45-
similarity: str,
46-
filters: Optional[List[str]] = None,
47-
vector_index_options: dict | None = None,
48-
*,
49-
wait_until_complete: Optional[float] = None,
50-
**kwargs: Any,
51-
) -> None:
52-
"""Experimental Utility function to create a vector search index
53-
54-
Args:
55-
collection (Collection): MongoDB Collection
56-
index_name (str): Name of Index
57-
dimensions (int): Number of dimensions in embedding
58-
path (str): field with vector embedding
59-
similarity (str): The similarity score used for the index
60-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
61-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
62-
until search index is ready.
63-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
64-
"""
65-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
66-
67-
if collection.name not in collection.database.list_collection_names(
68-
authorizedCollections=True
69-
):
70-
collection.database.create_collection(collection.name)
71-
72-
result = collection.create_search_index(
73-
SearchIndexModel(
74-
definition=_vector_search_index_definition(
75-
dimensions=dimensions,
76-
path=path,
77-
similarity=similarity,
78-
filters=filters,
79-
vector_index_options=vector_index_options,
80-
**kwargs,
81-
),
82-
name=index_name,
83-
type="vectorSearch",
84-
)
85-
)
86-
87-
if wait_until_complete:
88-
_wait_for_predicate(
89-
predicate=lambda: _is_index_ready(collection, index_name),
90-
err=f"{index_name=} did not complete in {wait_until_complete}!",
91-
timeout=wait_until_complete,
92-
)
93-
logger.info(result)
94-
95-
96-
def drop_vector_search_index(
97-
collection: Collection,
98-
index_name: str,
99-
*,
100-
wait_until_complete: Optional[float] = None,
101-
) -> None:
102-
"""Drop a created vector search index
103-
104-
Args:
105-
collection (Collection): MongoDB Collection with index to be dropped
106-
index_name (str): Name of the MongoDB index
107-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
108-
until search index is ready.
109-
"""
110-
logger.info(
111-
"Dropping Search Index %s from Collection: %s", index_name, collection.name
112-
)
113-
collection.drop_search_index(index_name)
114-
if wait_until_complete:
115-
_wait_for_predicate(
116-
predicate=lambda: len(list(collection.list_search_indexes())) == 0,
117-
err=f"Index {index_name} did not drop in {wait_until_complete}!",
118-
timeout=wait_until_complete,
119-
)
120-
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
121-
122-
123-
def update_vector_search_index(
124-
collection: Collection,
125-
index_name: str,
126-
dimensions: int,
127-
path: str,
128-
similarity: str,
129-
filters: Optional[List[str]] = None,
130-
vector_index_options: dict | None = None,
131-
*,
132-
wait_until_complete: Optional[float] = None,
133-
**kwargs: Any,
134-
) -> None:
135-
"""Update a search index.
136-
137-
Replace the existing index definition with the provided definition.
138-
139-
Args:
140-
collection (Collection): MongoDB Collection
141-
index_name (str): Name of Index
142-
dimensions (int): Number of dimensions in embedding
143-
path (str): field with vector embedding
144-
similarity (str): The similarity score used for the index.
145-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
146-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
147-
until search index is ready.
148-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
149-
"""
150-
logger.info(
151-
"Updating Search Index %s from Collection: %s", index_name, collection.name
152-
)
153-
collection.update_search_index(
154-
name=index_name,
155-
definition=_vector_search_index_definition(
156-
dimensions=dimensions,
157-
path=path,
158-
similarity=similarity,
159-
filters=filters,
160-
vector_index_options=vector_index_options,
161-
**kwargs,
162-
),
163-
)
164-
if wait_until_complete:
165-
_wait_for_predicate(
166-
predicate=lambda: _is_index_ready(collection, index_name),
167-
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
168-
timeout=wait_until_complete,
169-
)
170-
logger.info("Update succeeded")
171-
172-
17345
def _is_index_ready(collection: Collection, index_name: str) -> bool:
17446
"""Check for the index name in the list of available search indexes to see if the
17547
specified index is of status READY
@@ -206,50 +78,3 @@ def _wait_for_predicate(
20678
if monotonic() - start > timeout:
20779
raise TimeoutError(err)
20880
sleep(interval)
209-
210-
211-
def create_fulltext_search_index(
212-
collection: Collection,
213-
index_name: str,
214-
field: Union[str, List[str]],
215-
*,
216-
wait_until_complete: Optional[float] = None,
217-
**kwargs: Any,
218-
) -> None:
219-
"""Experimental Utility function to create an Atlas Search index
220-
221-
Args:
222-
collection (Collection): MongoDB Collection
223-
index_name (str): Name of Index
224-
field (str): Field to index
225-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
226-
until search index is ready
227-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
228-
"""
229-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
230-
231-
if collection.name not in collection.database.list_collection_names(
232-
authorizedCollections=True
233-
):
234-
collection.database.create_collection(collection.name)
235-
236-
if isinstance(field, str):
237-
fields_definition = {field: [{"type": "string"}]}
238-
else:
239-
fields_definition = {f: [{"type": "string"}] for f in field}
240-
definition = {"mappings": {"dynamic": False, "fields": fields_definition}}
241-
result = collection.create_search_index(
242-
SearchIndexModel(
243-
definition=definition,
244-
name=index_name,
245-
type="search",
246-
**kwargs,
247-
)
248-
)
249-
if wait_until_complete:
250-
_wait_for_predicate(
251-
predicate=lambda: _is_index_ready(collection, index_name),
252-
err=f"{index_name=} did not complete in {wait_until_complete}!",
253-
timeout=wait_until_complete,
254-
)
255-
logger.info(result)

libs/langchain-mongodb/langchain_mongodb/pipelines.py

Lines changed: 7 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99

1010
from typing import Any, Dict, List, Optional, Union
1111

12+
from pymongo_search_utils import (
13+
combine_pipelines, # noqa: F401
14+
final_hybrid_stage, # noqa: F401
15+
reciprocal_rank_stage, # noqa: F401
16+
vector_search_stage, # noqa: F401
17+
)
18+
1219

1320
def text_search_stage(
1421
query: str,
@@ -48,115 +55,3 @@ def text_search_stage(
4855
pipeline.append({"$limit": limit}) # type: ignore
4956

5057
return pipeline # type: ignore
51-
52-
53-
def vector_search_stage(
54-
query_vector: List[float],
55-
search_field: str,
56-
index_name: str,
57-
top_k: int = 4,
58-
filter: Optional[Dict[str, Any]] = None,
59-
oversampling_factor: int = 10,
60-
**kwargs: Any,
61-
) -> Dict[str, Any]: # noqa: E501
62-
"""Vector Search Stage without Scores.
63-
64-
Scoring is applied later depending on strategy.
65-
vector search includes a vectorSearchScore that is typically used.
66-
hybrid uses Reciprocal Rank Fusion.
67-
68-
Args:
69-
query_vector: List of embedding vector
70-
search_field: Field in Collection containing embedding vectors
71-
index_name: Name of Atlas Vector Search Index tied to Collection
72-
top_k: Number of documents to return
73-
oversampling_factor: this times limit is the number of candidates
74-
filter: MQL match expression comparing an indexed field.
75-
Some operators are not supported.
76-
See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
77-
78-
79-
Returns:
80-
Dictionary defining the $vectorSearch
81-
"""
82-
stage = {
83-
"index": index_name,
84-
"path": search_field,
85-
"queryVector": query_vector,
86-
"numCandidates": top_k * oversampling_factor,
87-
"limit": top_k,
88-
}
89-
if filter:
90-
stage["filter"] = filter
91-
return {"$vectorSearch": stage}
92-
93-
94-
def combine_pipelines(
95-
pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
96-
) -> None:
97-
"""Combines two aggregations into a single result set in-place."""
98-
if pipeline:
99-
pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
100-
else:
101-
pipeline.extend(stage)
102-
103-
104-
def reciprocal_rank_stage(
105-
score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
106-
) -> List[Dict[str, Any]]:
107-
"""
108-
Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring.
109-
110-
First, it groups documents into an array, assigns rank by array index,
111-
and then computes a weighted RRF score.
112-
113-
Args:
114-
score_field: A unique string to identify the search being ranked.
115-
penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator.
116-
weight: A float multiplier for this source's importance.
117-
**kwargs: Ignored; allows future extensions or passthrough args.
118-
119-
Returns:
120-
Aggregation pipeline stage for weighted RRF scoring.
121-
"""
122-
123-
return [
124-
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
125-
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
126-
{
127-
"$addFields": {
128-
f"docs.{score_field}": {
129-
"$multiply": [
130-
weight,
131-
{"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
132-
]
133-
},
134-
"docs.rank": "$rank",
135-
"_id": "$docs._id",
136-
}
137-
},
138-
{"$replaceRoot": {"newRoot": "$docs"}},
139-
]
140-
141-
142-
def final_hybrid_stage(
143-
scores_fields: List[str], limit: int, **kwargs: Any
144-
) -> List[Dict[str, Any]]:
145-
"""Sum weighted scores, sort, and apply limit.
146-
147-
Args:
148-
scores_fields: List of fields given to scores of vector and text searches
149-
limit: Number of documents to return
150-
151-
Returns:
152-
Final aggregation stages
153-
"""
154-
155-
return [
156-
{"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
157-
{"$replaceRoot": {"newRoot": "$docs"}},
158-
{"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
159-
{"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
160-
{"$sort": {"score": -1}},
161-
{"$limit": limit},
162-
]

libs/langchain-mongodb/langchain_mongodb/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import numpy as np
2727
from pymongo import MongoClient
2828
from pymongo.driver_info import DriverInfo
29+
from pymongo_search_utils import append_client_metadata
2930

3031
logger = logging.getLogger(__name__)
3132

@@ -35,9 +36,7 @@
3536

3637

3738
def _append_client_metadata(client: MongoClient) -> None:
38-
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
39-
if callable(client.append_metadata):
40-
client.append_metadata(DRIVER_METADATA)
39+
append_client_metadata(client=client, driver_info=DRIVER_METADATA)
4140

4241

4342
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:

0 commit comments

Comments
 (0)