@@ -52,6 +52,45 @@ def evaluate_mrr(
     return mrr


+# Modified from https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L33
+def evaluate_recall_cap(
+    qrels: Dict[str, Dict[str, int]],
+    results: Dict[str, Dict[str, float]],
+    k_values: List[int]
+) -> Dict[str, float]:
+    """Compute capped recall.
+
+    Capped recall at k divides the number of relevant documents retrieved in the
+    top k by min(number of relevant documents, k), averaged over all queries.
+
+    Args:
+        qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
+        results (Dict[str, Dict[str, float]]): Search results to evaluate.
+        k_values (List[int]): Cutoffs.
+
+    Returns:
+        Dict[str, float]: Capped recall at each provided k value.
+    """
+    capped_recall = {}
+
+    for k in k_values:
+        capped_recall[f"R_cap@{k}"] = 0.0
+
+    k_max = max(k_values)
+    logging.info("\n")
+
+    for query_id, doc_scores in results.items():
+        # Rank documents by score and keep the top k_max hits.
+        top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
+        query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
+        for k in k_values:
+            retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0]
+            # Cap the denominator at k so queries with many relevant documents
+            # are not penalized by a small cutoff.
+            denominator = min(len(query_relevant_docs), k)
+            capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator)
+
+    for k in k_values:
+        capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"] / len(qrels), 5)
+        logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"]))
+
+    return capped_recall
+
+
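+# Hypothetical usage sketch (the qrels/results below are made up for illustration,
+# not taken from this module):
+#   qrels = {"q1": {"d1": 1, "d2": 0, "d3": 1}}
+#   results = {"q1": {"d1": 0.9, "d2": 0.8, "d3": 0.1}}
+#   evaluate_recall_cap(qrels, results, k_values=[1, 2])
+#   # -> {"R_cap@1": 1.0, "R_cap@2": 0.5}
+
+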
 # Modified from https://github.com/embeddings-benchmark/mteb/blob/18f730696451a5aaa026494cecf288fd5cde9fd0/mteb/evaluation/evaluators/RetrievalEvaluator.py#L501
 def evaluate_metrics(
     qrels: Dict[str, Dict[str, int]],