Skip to content

Commit bf6998f

Browse files
authored
feat(INTPYTHON-542): support weighted RRF (#168)
1 parent 065c703 commit bf6998f

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

libs/langchain-mongodb/langchain_mongodb/pipelines.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,30 +102,34 @@ def combine_pipelines(
102102

103103

104104
def reciprocal_rank_stage(
105-
score_field: str, penalty: float = 0, **kwargs: Any
105+
score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
106106
) -> List[Dict[str, Any]]:
107-
"""Stage adds Reciprocal Rank Fusion weighting.
107+
"""
108+
Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring.
108109
109-
First, it pushes documents retrieved from previous stage
110-
into a temporary sub-document. It then unwinds to establish
111-
the rank to each and applies the penalty.
110+
First, it groups documents into an array, assigns rank by array index,
111+
and then computes a weighted RRF score.
112112
113113
Args:
114-
score_field: A unique string to identify the search being ranked
115-
penalty: A non-negative float.
116-
extra_fields: Any fields other than text_field that one wishes to keep.
114+
score_field: A unique string to identify the search being ranked.
115+
penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator.
116+
weight: A float multiplier for this source's importance.
117+
**kwargs: Ignored; allows future extensions or passthrough args.
117118
118119
Returns:
119-
RRF score
120+
Aggregation pipeline stage for weighted RRF scoring.
120121
"""
121122

122-
rrf_pipeline = [
123+
return [
123124
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
124125
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
125126
{
126127
"$addFields": {
127128
f"docs.{score_field}": {
128-
"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]
129+
"$multiply": [
130+
weight,
131+
{"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
132+
]
129133
},
130134
"docs.rank": "$rank",
131135
"_id": "$docs._id",
@@ -134,8 +138,6 @@ def reciprocal_rank_stage(
134138
{"$replaceRoot": {"newRoot": "$docs"}},
135139
]
136140

137-
return rrf_pipeline # type: ignore
138-
139141

140142
def final_hybrid_stage(
141143
scores_fields: List[str], limit: int, **kwargs: Any

libs/langchain-mongodb/langchain_mongodb/retrievers/hybrid_search.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
4343
"""Penalty applied to vector search results in RRF: score = 1 / (rank + penalty + 1), where rank is the zero-based position in the result list"""
4444
fulltext_penalty: float = 60.0
4545
"""Penalty applied to full-text search results in RRF: score = 1 / (rank + penalty + 1), where rank is the zero-based position in the result list"""
46+
vector_weight: float = 1.0
47+
"""Weight applied to vector search results in RRF: score = weight * (1 / (rank + penalty + 1))"""
48+
fulltext_weight: float = 1.0
49+
"""Weight applied to full-text search results in RRF: score = weight * (1 / (rank + penalty + 1))"""
4650
show_embeddings: float = False
4751
"""If true, returned Document metadata will include vectors."""
4852
top_k: Annotated[
@@ -95,7 +99,11 @@ def _get_relevant_documents(
9599
oversampling_factor=self.oversampling_factor,
96100
)
97101
]
98-
vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty)
102+
vector_pipeline += reciprocal_rank_stage(
103+
score_field="vector_score",
104+
penalty=self.vector_penalty,
105+
weight=self.vector_weight,
106+
)
99107

100108
combine_pipelines(pipeline, vector_pipeline, self.collection.name)
101109

@@ -109,7 +117,11 @@ def _get_relevant_documents(
109117
)
110118

111119
text_pipeline.extend(
112-
reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
120+
reciprocal_rank_stage(
121+
score_field="fulltext_score",
122+
penalty=self.fulltext_penalty,
123+
weight=self.fulltext_weight,
124+
)
113125
)
114126

115127
combine_pipelines(pipeline, text_pipeline, self.collection.name)

0 commit comments

Comments
 (0)