
Commit 6fd1762

Merge pull request #1553 from hanhainebula/master
add metric for evaluation: evaluate_recall_cap
2 parents bbc736f + d166f49 commit 6fd1762

2 files changed: +46 −1 lines changed

FlagEmbedding/abc/evaluation/evaluator.py

Lines changed: 7 additions & 1 deletion
@@ -10,7 +10,7 @@

 from .data_loader import AbsEvalDataLoader
 from .searcher import EvalRetriever, EvalReranker
-from .utils import evaluate_metrics, evaluate_mrr
+from .utils import evaluate_metrics, evaluate_mrr, evaluate_recall_cap

 logger = logging.getLogger(__name__)

@@ -340,12 +340,18 @@ def compute_metrics(
             results=search_results,
             k_values=k_values,
         )
+        recall_cap = evaluate_recall_cap(
+            qrels=qrels,
+            results=search_results,
+            k_values=k_values,
+        )
         scores = {
             **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
             **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
             **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
             **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
             **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
+            **{f"recall_cap_at_{k.split('@')[1]}": v for (k, v) in recall_cap.items()},
         }
         return scores
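For reference, evaluate_recall_cap returns keys of the form "R_cap@{k}", and the added comprehension in compute_metrics renames them to "recall_cap_at_{k}" in the merged scores, mirroring the other metrics. A minimal sketch of that renaming (not part of the PR; the metric values are made up):

# Minimal sketch (not from the PR): how the added comprehension in compute_metrics
# turns evaluate_recall_cap's "R_cap@{k}" keys into "recall_cap_at_{k}" keys.
recall_cap = {"R_cap@10": 0.66, "R_cap@100": 0.89}  # hypothetical metric values
scores = {f"recall_cap_at_{k.split('@')[1]}": v for (k, v) in recall_cap.items()}
print(scores)  # {'recall_cap_at_10': 0.66, 'recall_cap_at_100': 0.89}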

FlagEmbedding/abc/evaluation/utils.py

Lines changed: 39 additions & 0 deletions
@@ -52,6 +52,45 @@ def evaluate_mrr(
     return mrr


+# Modified from https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L33
+def evaluate_recall_cap(
+    qrels: Dict[str, Dict[str, int]],
+    results: Dict[str, Dict[str, float]],
+    k_values: List[int]
+) -> Tuple[Dict[str, float]]:
+    """Compute capped recall.
+
+    Args:
+        qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
+        results (Dict[str, Dict[str, float]]): Search results to evaluate.
+        k_values (List[int]): Cutoffs.
+
+    Returns:
+        Tuple[Dict[str, float]]: Capped recall results at provided k values.
+    """
+    capped_recall = {}
+
+    for k in k_values:
+        capped_recall[f"R_cap@{k}"] = 0.0
+
+    k_max = max(k_values)
+    logging.info("\n")
+
+    for query_id, doc_scores in results.items():
+        top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
+        query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
+        for k in k_values:
+            retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0]
+            denominator = min(len(query_relevant_docs), k)
+            capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator)
+
+    for k in k_values:
+        capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5)
+        logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"]))
+
+    return capped_recall
+
+
 # Modified from https://github.com/embeddings-benchmark/mteb/blob/18f730696451a5aaa026494cecf288fd5cde9fd0/mteb/evaluation/evaluators/RetrievalEvaluator.py#L501
 def evaluate_metrics(
     qrels: Dict[str, Dict[str, int]],
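The new metric follows BEIR's capped recall: for each query, the number of relevant documents retrieved in the top k is divided by min(number of relevant documents, k) rather than by the full relevant count, and the per-query values are averaged over len(qrels). A quick sanity-check sketch, assuming the FlagEmbedding package from this branch is importable; the qrels/results below are made up:

# Sanity-check sketch for evaluate_recall_cap (toy data, not from the PR).
from FlagEmbedding.abc.evaluation.utils import evaluate_recall_cap

qrels = {
    "q1": {"d1": 1, "d2": 1, "d3": 0},  # two relevant docs for q1
    "q2": {"d4": 1},                    # one relevant doc for q2
}
results = {
    "q1": {"d1": 0.9, "d5": 0.8, "d2": 0.1},
    "q2": {"d4": 0.7, "d6": 0.3},
}

print(evaluate_recall_cap(qrels=qrels, results=results, k_values=[1, 2]))
# q1: R_cap@1 = 1/min(2,1) = 1.0, R_cap@2 = 1/min(2,2) = 0.5 (d5 is not relevant)
# q2: R_cap@1 = 1/min(1,1) = 1.0, R_cap@2 = 1/min(1,2) = 1.0
# averaged over the 2 queries -> {'R_cap@1': 1.0, 'R_cap@2': 0.75}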
