|
12 | 12 | from pymongo_search_utils import ( |
13 | 13 | combine_pipelines, # noqa: F401 |
14 | 14 | final_hybrid_stage, # noqa: F401 |
15 | | - reciprocal_rank_stage, # noqa: F401 |
16 | 15 | ) |
17 | 16 |
|
18 | 17 |
|
@@ -95,3 +94,41 @@ def vector_search_stage( |
95 | 94 | if filter: |
96 | 95 | stage["filter"] = filter |
97 | 96 | return {"$vectorSearch": stage} |
| 97 | + |
| 98 | + |
def reciprocal_rank_stage(
    score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
) -> List[Dict[str, Any]]:
    """Build the stages that attach a Weighted Reciprocal Rank Fusion score.

    The returned stages collect every incoming document into one array,
    unwind that array while recording each document's array index as its
    rank, and then store ``weight / (rank + penalty + 1)`` on the document
    under ``score_field`` before promoting the document back to the root.

    Args:
        score_field: Name of the field that receives the weighted score;
            should be unique per search being ranked.
        penalty: Non-negative constant added to the denominator
            (e.g., 60 for RRF-60).
        weight: Multiplier expressing the importance of this source.
        **kwargs: Accepted but unused; lets callers pass extra arguments.

    Returns:
        A list of four aggregation stages, in order: ``$group``,
        ``$unwind``, ``$addFields`` and ``$replaceRoot``.
    """
    # "$rank" is the array index written by $unwind; the extra 1 offsets it
    # so the first position does not divide by the bare penalty alone.
    reciprocal = {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}
    weighted_score = {"$multiply": [weight, reciprocal]}

    # Gather all documents into a single array so their position can be read.
    collect_docs = {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}
    # Emit one document per array element, exposing its index as "rank".
    assign_rank = {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}
    # Write the score and rank onto the nested document and restore its _id.
    annotate = {
        "$addFields": {
            f"docs.{score_field}": weighted_score,
            "docs.rank": "$rank",
            "_id": "$docs._id",
        }
    }
    # Replace the wrapper with the original (now annotated) document.
    promote_docs = {"$replaceRoot": {"newRoot": "$docs"}}

    return [collect_docs, assign_rank, annotate, promote_docs]
0 commit comments