70 changes: 48 additions & 22 deletions sentence_transformers/evaluation/NanoBEIREvaluator.py
@@ -2,7 +2,7 @@

import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable, Literal, get_args

import numpy as np
from torch import Tensor
@@ -35,21 +35,32 @@
"touche2020",
]

LanguageType = Literal[
"ar",
"de",
"en",
"es",
"fr",
"it",
"no",
"pt",
"sv",
]

dataset_name_to_id = {
"climatefever": "zeta-alpha-ai/NanoClimateFEVER",
"dbpedia": "zeta-alpha-ai/NanoDBPedia",
"fever": "zeta-alpha-ai/NanoFEVER",
"fiqa2018": "zeta-alpha-ai/NanoFiQA2018",
"hotpotqa": "zeta-alpha-ai/NanoHotpotQA",
"msmarco": "zeta-alpha-ai/NanoMSMARCO",
"nfcorpus": "zeta-alpha-ai/NanoNFCorpus",
"nq": "zeta-alpha-ai/NanoNQ",
"quoraretrieval": "zeta-alpha-ai/NanoQuoraRetrieval",
"scidocs": "zeta-alpha-ai/NanoSCIDOCS",
"arguana": "zeta-alpha-ai/NanoArguAna",
"scifact": "zeta-alpha-ai/NanoSciFact",
"touche2020": "zeta-alpha-ai/NanoTouche2020",
dataset_name_to_subset_id = {
"climatefever": "NanoClimateFEVER",
"dbpedia": "NanoDBPedia",
"fever": "NanoFEVER",
"fiqa2018": "NanoFiQA2018",
"hotpotqa": "NanoHotpotQA",
"msmarco": "NanoMSMARCO",
"nfcorpus": "NanoNFCorpus",
"nq": "NanoNQ",
"quoraretrieval": "NanoQuoraRetrieval",
"scidocs": "NanoSCIDOCS",
"arguana": "NanoArguAna",
"scifact": "NanoSciFact",
"touche2020": "NanoTouche2020",
}

dataset_name_to_human_readable = {
@@ -80,6 +91,7 @@ class NanoBEIREvaluator(SentenceEvaluator):

Args:
dataset_names (List[str]): The names of the datasets to evaluate on. Defaults to all datasets.
language (str): The language of the NanoBEIR collection. Supports Arabic (ar), German (de), English (en), Spanish (es), French (fr), Italian (it), Norwegian (no), Portuguese (pt), and Swedish (sv). Defaults to English (en).
mrr_at_k (List[int]): A list of integers representing the values of k for MRR calculation. Defaults to [10].
ndcg_at_k (List[int]): A list of integers representing the values of k for NDCG calculation. Defaults to [10].
accuracy_at_k (List[int]): A list of integers representing the values of k for accuracy calculation. Defaults to [1, 3, 5, 10].
@@ -193,6 +205,7 @@ class NanoBEIREvaluator(SentenceEvaluator):
def __init__(
self,
dataset_names: list[DatasetNameType] | None = None,
language: LanguageType | None = "en",
mrr_at_k: list[int] = [10],
ndcg_at_k: list[int] = [10],
accuracy_at_k: list[int] = [1, 3, 5, 10],
@@ -212,7 +225,8 @@ def __init__(
):
super().__init__()
if dataset_names is None:
dataset_names = list(dataset_name_to_id.keys())
dataset_names = list(dataset_name_to_subset_id.keys())
self.language = language
self.dataset_names = dataset_names
self.aggregate_fn = aggregate_fn
self.aggregate_key = aggregate_key
@@ -236,6 +250,7 @@ def __init__(
self.map_at_k = map_at_k

self._validate_dataset_names()
self._validate_language()
self._validate_prompts()

ir_evaluator_kwargs = {
@@ -409,10 +424,11 @@ def _load_dataset(self, dataset_name: DatasetNameType, **ir_evaluator_kwargs) ->
)
from datasets import load_dataset

dataset_path = dataset_name_to_id[dataset_name.lower()]
corpus = load_dataset(dataset_path, "corpus", split="train")
queries = load_dataset(dataset_path, "queries", split="train")
qrels = load_dataset(dataset_path, "qrels", split="train")
dataset_path = "lightonai/nanobeir-multilingual"
subset_id = dataset_name_to_subset_id[dataset_name.lower()]
corpus = load_dataset(dataset_path, f"{subset_id}_{self.language}", split="corpus")
queries = load_dataset(dataset_path, f"{subset_id}_{self.language}", split="queries")
qrels = load_dataset(dataset_path, subset_id, split="qrels")
corpus_dict = {sample["_id"]: sample["text"] for sample in corpus if len(sample["text"]) > 0}
queries_dict = {sample["_id"]: sample["text"] for sample in queries if len(sample["text"]) > 0}
qrels_dict = {}
@@ -438,11 +454,21 @@ def _validate_dataset_names(self):
if len(self.dataset_names) == 0:
raise ValueError("dataset_names cannot be empty. Use None to evaluate on all datasets.")
if missing_datasets := [
dataset_name for dataset_name in self.dataset_names if dataset_name.lower() not in dataset_name_to_id
dataset_name
for dataset_name in self.dataset_names
if dataset_name.lower() not in dataset_name_to_subset_id
]:
raise ValueError(
f"Dataset(s) {missing_datasets} not found in the NanoBEIR collection. "
f"Valid dataset names are: {list(dataset_name_to_id.keys())}"
f"Valid dataset names are: {list(dataset_name_to_subset_id.keys())}"
)

def _validate_language(self):
valid_languages = list(get_args(LanguageType))
if self.language not in valid_languages:
raise ValueError(
f"Language '{self.language}' not found in the NanoBEIR multilingual collection. "
f"Valid languages are: {valid_languages}"
)

def _validate_prompts(self):
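The `_load_dataset` change above replaces the per-dataset `zeta-alpha-ai/Nano*` repositories with a single multilingual repository whose configs are named `{subset_id}_{language}` for corpus and queries, and plain `{subset_id}` for qrels. A minimal sketch of how one dataset/language pair resolves under this scheme (the `hotpotqa`/`fr` pair is an illustrative choice; repository id, config names, and splits mirror the diff):

from datasets import load_dataset

# Resolve one dataset/language pair the way the new _load_dataset does.
dataset_path = "lightonai/nanobeir-multilingual"
subset_id = "NanoHotpotQA"  # dataset_name_to_subset_id["hotpotqa"]
language = "fr"

corpus = load_dataset(dataset_path, f"{subset_id}_{language}", split="corpus")
queries = load_dataset(dataset_path, f"{subset_id}_{language}", split="queries")
qrels = load_dataset(dataset_path, subset_id, split="qrels")  # the qrels config carries no language suffix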
sentence_transformers/sparse_encoder/evaluation/SparseNanoBEIREvaluator.py
@@ -18,7 +18,7 @@
from torch import Tensor

from sentence_transformers.evaluation import SimilarityFunction
from sentence_transformers.evaluation.NanoBEIREvaluator import DatasetNameType
from sentence_transformers.evaluation.NanoBEIREvaluator import DatasetNameType, LanguageType
from sentence_transformers.sparse_encoder import SparseEncoder

logger = logging.getLogger(__name__)
@@ -37,6 +37,7 @@ class SparseNanoBEIREvaluator(NanoBEIREvaluator):

Args:
dataset_names (List[str]): The names of the datasets to evaluate on. Defaults to all datasets.
language (str): The language of the NanoBEIR collection. Supports Arabic (ar), German (de), English (en), Spanish (es), French (fr), Italian (it), Norwegian (no), Portuguese (pt), and Swedish (sv). Defaults to English (en).
mrr_at_k (List[int]): A list of integers representing the values of k for MRR calculation. Defaults to [10].
ndcg_at_k (List[int]): A list of integers representing the values of k for NDCG calculation. Defaults to [10].
accuracy_at_k (List[int]): A list of integers representing the values of k for accuracy calculation. Defaults to [1, 3, 5, 10].
@@ -164,6 +165,7 @@ class SparseNanoBEIREvaluator(NanoBEIREvaluator):
def __init__(
self,
dataset_names: list[DatasetNameType] | None = None,
language: LanguageType | None = "en",
mrr_at_k: list[int] = [10],
ndcg_at_k: list[int] = [10],
accuracy_at_k: list[int] = [1, 3, 5, 10],
@@ -185,6 +187,7 @@ def __init__(
self.sparsity_stats = defaultdict(list)
super().__init__(
dataset_names=dataset_names,
language=language,
mrr_at_k=mrr_at_k,
ndcg_at_k=ndcg_at_k,
accuracy_at_k=accuracy_at_k,
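For context, a minimal usage sketch of the forwarded argument. The import path and the SPLADE checkpoint name are assumptions taken from the sentence-transformers documentation, not from this diff:

from sentence_transformers.sparse_encoder import SparseEncoder
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator

# Illustrative checkpoint; any SparseEncoder-compatible model works.
model = SparseEncoder("naver/splade-cocondenser-ensembledistil")

# `language` is forwarded as-is to the parent NanoBEIREvaluator.
evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus"], language="es")
results = evaluator(model)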
14 changes: 14 additions & 0 deletions tests/evaluation/test_nanobeir_evaluator.py
@@ -61,3 +61,17 @@ def test_nanobeir_evaluator_empty_inputs():
"""Test that NanoBEIREvaluator behaves correctly with empty datasets."""
with pytest.raises(ValueError, match="dataset_names cannot be empty. Use None to evaluate on all datasets."):
NanoBEIREvaluator(dataset_names=[])


def test_nanobeir_evaluator_invalid_language():
"""Test that NanoBEIREvaluator raises an error for invalid languages."""
invalid_language = "nl"

with pytest.raises(
ValueError,
match=re.escape(
r"Language 'nl' not found in the NanoBEIR multilingual collection. "
r"Valid languages are: ['ar', 'de', 'en', 'es', 'fr', 'it', 'no', 'pt', 'sv']"
),
):
NanoBEIREvaluator(language=invalid_language)
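The message asserted here is produced by the new `_validate_language`, which derives the accepted codes from the `LanguageType` `Literal` via `get_args` rather than maintaining a separate list. A standalone sketch of that pattern, using the same codes as the diff:

from typing import Literal, get_args

LanguageType = Literal["ar", "de", "en", "es", "fr", "it", "no", "pt", "sv"]

# get_args() returns the Literal members as a tuple, so the set of valid codes
# stays in sync with the type annotation.
valid_languages = list(get_args(LanguageType))
assert "de" in valid_languages
assert "nl" not in valid_languages  # the invalid code exercised by the test above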