diff --git a/sentence_transformers/evaluation/NanoBEIREvaluator.py b/sentence_transformers/evaluation/NanoBEIREvaluator.py index a7cdc2ddb..1aee69c6d 100644 --- a/sentence_transformers/evaluation/NanoBEIREvaluator.py +++ b/sentence_transformers/evaluation/NanoBEIREvaluator.py @@ -2,7 +2,7 @@ import logging import os -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal, get_args import numpy as np from torch import Tensor @@ -35,21 +35,32 @@ "touche2020", ] +LanguageType = Literal[ + "ar", + "de", + "en", + "es", + "fr", + "it", + "no", + "pt", + "sv", +] -dataset_name_to_id = { - "climatefever": "zeta-alpha-ai/NanoClimateFEVER", - "dbpedia": "zeta-alpha-ai/NanoDBPedia", - "fever": "zeta-alpha-ai/NanoFEVER", - "fiqa2018": "zeta-alpha-ai/NanoFiQA2018", - "hotpotqa": "zeta-alpha-ai/NanoHotpotQA", - "msmarco": "zeta-alpha-ai/NanoMSMARCO", - "nfcorpus": "zeta-alpha-ai/NanoNFCorpus", - "nq": "zeta-alpha-ai/NanoNQ", - "quoraretrieval": "zeta-alpha-ai/NanoQuoraRetrieval", - "scidocs": "zeta-alpha-ai/NanoSCIDOCS", - "arguana": "zeta-alpha-ai/NanoArguAna", - "scifact": "zeta-alpha-ai/NanoSciFact", - "touche2020": "zeta-alpha-ai/NanoTouche2020", +dataset_name_to_subset_id = { + "climatefever": "NanoClimateFEVER", + "dbpedia": "NanoDBPedia", + "fever": "NanoFEVER", + "fiqa2018": "NanoFiQA2018", + "hotpotqa": "NanoHotpotQA", + "msmarco": "NanoMSMARCO", + "nfcorpus": "NanoNFCorpus", + "nq": "NanoNQ", + "quoraretrieval": "NanoQuoraRetrieval", + "scidocs": "NanoSCIDOCS", + "arguana": "NanoArguAna", + "scifact": "NanoSciFact", + "touche2020": "NanoTouche2020", } dataset_name_to_human_readable = { @@ -80,6 +91,7 @@ class NanoBEIREvaluator(SentenceEvaluator): Args: dataset_names (List[str]): The names of the datasets to evaluate on. Defaults to all datasets. + language (str): The language of the NanoBEIR collection. Supports Arabic (ar), German (de), English (en), Spanish (es), French (fr), Italian (it), Norwegian (no), Portuguese (pt), and Swedish (sv). Defaults to English (en). mrr_at_k (List[int]): A list of integers representing the values of k for MRR calculation. Defaults to [10]. ndcg_at_k (List[int]): A list of integers representing the values of k for NDCG calculation. Defaults to [10]. accuracy_at_k (List[int]): A list of integers representing the values of k for accuracy calculation. Defaults to [1, 3, 5, 10]. @@ -193,6 +205,7 @@ class NanoBEIREvaluator(SentenceEvaluator): def __init__( self, dataset_names: list[DatasetNameType] | None = None, + language: LanguageType | None = "en", mrr_at_k: list[int] = [10], ndcg_at_k: list[int] = [10], accuracy_at_k: list[int] = [1, 3, 5, 10], @@ -212,7 +225,8 @@ def __init__( ): super().__init__() if dataset_names is None: - dataset_names = list(dataset_name_to_id.keys()) + dataset_names = list(dataset_name_to_subset_id.keys()) + self.language = language self.dataset_names = dataset_names self.aggregate_fn = aggregate_fn self.aggregate_key = aggregate_key @@ -236,6 +250,7 @@ def __init__( self.map_at_k = map_at_k self._validate_dataset_names() + self._validate_language() self._validate_prompts() ir_evaluator_kwargs = { @@ -409,10 +424,11 @@ def _load_dataset(self, dataset_name: DatasetNameType, **ir_evaluator_kwargs) -> ) from datasets import load_dataset - dataset_path = dataset_name_to_id[dataset_name.lower()] - corpus = load_dataset(dataset_path, "corpus", split="train") - queries = load_dataset(dataset_path, "queries", split="train") - qrels = load_dataset(dataset_path, "qrels", split="train") + dataset_path = "lightonai/nanobeir-multilingual" + subset_id = dataset_name_to_subset_id[dataset_name.lower()] + corpus = load_dataset(dataset_path, f"{subset_id}_{self.language}", split="corpus") + queries = load_dataset(dataset_path, f"{subset_id}_{self.language}", split="queries") + qrels = load_dataset(dataset_path, subset_id, split="qrels") corpus_dict = {sample["_id"]: sample["text"] for sample in corpus if len(sample["text"]) > 0} queries_dict = {sample["_id"]: sample["text"] for sample in queries if len(sample["text"]) > 0} qrels_dict = {} @@ -438,11 +454,21 @@ def _validate_dataset_names(self): if len(self.dataset_names) == 0: raise ValueError("dataset_names cannot be empty. Use None to evaluate on all datasets.") if missing_datasets := [ - dataset_name for dataset_name in self.dataset_names if dataset_name.lower() not in dataset_name_to_id + dataset_name + for dataset_name in self.dataset_names + if dataset_name.lower() not in dataset_name_to_subset_id ]: raise ValueError( f"Dataset(s) {missing_datasets} not found in the NanoBEIR collection. " - f"Valid dataset names are: {list(dataset_name_to_id.keys())}" + f"Valid dataset names are: {list(dataset_name_to_subset_id.keys())}" + ) + + def _validate_language(self): + valid_languages = list(get_args(LanguageType)) + if self.language not in valid_languages: + raise ValueError( + f"Language '{self.language}' not found in the NanoBEIR multilingual collection. " + f"Valid languages are: {valid_languages}" ) def _validate_prompts(self): diff --git a/sentence_transformers/sparse_encoder/evaluation/SparseNanoBEIREvaluator.py b/sentence_transformers/sparse_encoder/evaluation/SparseNanoBEIREvaluator.py index d96a3e7a2..097f4a58f 100644 --- a/sentence_transformers/sparse_encoder/evaluation/SparseNanoBEIREvaluator.py +++ b/sentence_transformers/sparse_encoder/evaluation/SparseNanoBEIREvaluator.py @@ -18,7 +18,7 @@ from torch import Tensor from sentence_transformers.evaluation import SimilarityFunction - from sentence_transformers.evaluation.NanoBEIREvaluator import DatasetNameType + from sentence_transformers.evaluation.NanoBEIREvaluator import DatasetNameType, LanguageType from sentence_transformers.sparse_encoder import SparseEncoder logger = logging.getLogger(__name__) @@ -37,6 +37,7 @@ class SparseNanoBEIREvaluator(NanoBEIREvaluator): Args: dataset_names (List[str]): The names of the datasets to evaluate on. Defaults to all datasets. + language (str): The language of the NanoBEIR collection. Supports Arabic (ar), German (de), English (en), Spanish (es), French (fr), Italian (it), Norwegian (no), Portuguese (pt), and Swedish (sv). Defaults to English (en). mrr_at_k (List[int]): A list of integers representing the values of k for MRR calculation. Defaults to [10]. ndcg_at_k (List[int]): A list of integers representing the values of k for NDCG calculation. Defaults to [10]. accuracy_at_k (List[int]): A list of integers representing the values of k for accuracy calculation. Defaults to [1, 3, 5, 10]. @@ -164,6 +165,7 @@ class SparseNanoBEIREvaluator(NanoBEIREvaluator): def __init__( self, dataset_names: list[DatasetNameType] | None = None, + language: LanguageType | None = "en", mrr_at_k: list[int] = [10], ndcg_at_k: list[int] = [10], accuracy_at_k: list[int] = [1, 3, 5, 10], @@ -185,6 +187,7 @@ def __init__( self.sparsity_stats = defaultdict(list) super().__init__( dataset_names=dataset_names, + language=language, mrr_at_k=mrr_at_k, ndcg_at_k=ndcg_at_k, accuracy_at_k=accuracy_at_k, diff --git a/tests/evaluation/test_nanobeir_evaluator.py b/tests/evaluation/test_nanobeir_evaluator.py index c637599ed..438515c9a 100644 --- a/tests/evaluation/test_nanobeir_evaluator.py +++ b/tests/evaluation/test_nanobeir_evaluator.py @@ -61,3 +61,17 @@ def test_nanobeir_evaluator_empty_inputs(): """Test that NanoBEIREvaluator behaves correctly with empty datasets.""" with pytest.raises(ValueError, match="dataset_names cannot be empty. Use None to evaluate on all datasets."): NanoBEIREvaluator(dataset_names=[]) + + +def test_nanobeir_evaluator_invalid_language(): + """Test that NanoBEIREvaluator raises an error for invalid languages.""" + invalid_language = "nl" + + with pytest.raises( + ValueError, + match=re.escape( + r"Language 'nl' not found in the NanoBEIR multilingual collection. " + r"Valid languages are: ['ar', 'de', 'en', 'es', 'fr', 'it', 'no', 'pt', 'sv']" + ), + ): + NanoBEIREvaluator(language=invalid_language)