diff --git a/FlagEmbedding/evaluation/bright/__init__.py b/FlagEmbedding/evaluation/bright/__init__.py new file mode 100644 index 00000000..15f0d65b --- /dev/null +++ b/FlagEmbedding/evaluation/bright/__init__.py @@ -0,0 +1,17 @@ +from FlagEmbedding.abc.evaluation import ( + AbsEvalModelArgs as BrightEvalModelArgs, +) + +from .data_loader import BrightShortEvalDataLoader, BrightLongEvalDataLoader +from .arguments import BrightEvalArgs +from .runner import BrightEvalRunner +from .searcher import BrightEvalDenseRetriever + +__all__ = [ + "BrightEvalArgs", + "BrightEvalModelArgs", + "BrightEvalRunner", + "BrightEvalDenseRetriever", + "BrightShortEvalDataLoader", + "BrightLongEvalDataLoader", +] diff --git a/FlagEmbedding/evaluation/bright/__main__.py b/FlagEmbedding/evaluation/bright/__main__.py new file mode 100644 index 00000000..b14d07ca --- /dev/null +++ b/FlagEmbedding/evaluation/bright/__main__.py @@ -0,0 +1,28 @@ +from transformers import HfArgumentParser + +from FlagEmbedding.evaluation.bright import ( + BrightEvalArgs, BrightEvalModelArgs, + BrightEvalRunner +) + + +def main(): + parser = HfArgumentParser(( + BrightEvalArgs, + BrightEvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: BrightEvalArgs + model_args: BrightEvalModelArgs + + runner = BrightEvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/evaluation/bright/arguments.py b/FlagEmbedding/evaluation/bright/arguments.py new file mode 100644 index 00000000..1e1b5bc3 --- /dev/null +++ b/FlagEmbedding/evaluation/bright/arguments.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass, field + +from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs + + +@dataclass +class BrightEvalArgs(AbsEvalArgs): + """ + Argument class for Bright evaluation. + """ + task_type: str = field( + default="short", metadata={"help": "The task type to evaluate on. Available options: ['short', 'long']. Default: short", "choices": ["short", "long"]} + ) + use_special_instructions: bool = field( + default=True, metadata={"help": "Whether to use specific instructions in `prompts.py` for evaluation. Default: True"} + ) diff --git a/FlagEmbedding/evaluation/bright/data_loader.py b/FlagEmbedding/evaluation/bright/data_loader.py new file mode 100644 index 00000000..3884996d --- /dev/null +++ b/FlagEmbedding/evaluation/bright/data_loader.py @@ -0,0 +1,399 @@ +import os +import json +import logging +import datasets +from tqdm import tqdm +from typing import List, Optional +from collections import defaultdict + +from FlagEmbedding.abc.evaluation import AbsEvalDataLoader + +logger = logging.getLogger(__name__) + + +class BrightShortEvalDataLoader(AbsEvalDataLoader): + """ + Data loader class for Bright(short). + """ + def available_dataset_names(self) -> List[str]: + """ + Get the available dataset names. + + Returns: + List[str]: All the available dataset names. + """ + return [ + # StackExchange + "biology", "earth_science", "economics", "psychology", "robotics", "stackoverflow", "sustainable_living", + # Coding + "leetcode", "pony", + # Theorem-based + "aops", "theoremqa_questions", "theoremqa_theorems" + ] + + def available_splits(self, dataset_name: str) -> List[str]: + """ + Get the avaialble splits. + + Args: + dataset_name (str): Dataset name. + + Returns: + List[str]: All the available splits for the dataset. 
+ """ + return [ + # normal splits + "examples", + # w/ reasoning splits + "Gemini-1.0_reason", "claude-3-opus_reason", "gpt4_reason", "grit_reason", "llama3-70b_reason", + ] + + def _load_remote_corpus( + self, + dataset_name: str, + save_dir: Optional[str] = None + ) -> datasets.DatasetDict: + """Load the corpus dataset from HF. + + Args: + dataset_name (str): Name of the dataset. + save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of corpus. + """ + corpus = datasets.load_dataset( + "xlangai/bright", "documents", + cache_dir=self.cache_dir, + download_mode=self.hf_download_mode + )[dataset_name] + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, "corpus.jsonl") + corpus_dict = {} + with open(save_path, "w", encoding="utf-8") as f: + for data in tqdm(corpus, desc="Loading and Saving corpus"): + docid, text = str(data["id"]), data["content"] + _data = { + "id": docid, + "text": text + } + corpus_dict[docid] = {"text": text} + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}") + else: + corpus_dict = {str(data["id"]): {"text": data["content"]} for data in tqdm(corpus, desc="Loading corpus")} + return datasets.DatasetDict(corpus_dict) + + def _load_remote_qrels( + self, + dataset_name: str, + split: str = 'examples', + save_dir: Optional[str] = None + ) -> datasets.DatasetDict: + """Load the qrels from HF. + + Args: + dataset_name (str): Name of the dataset. + split (str, optional): Split of the dataset. Defaults to ``'examples'``. + save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of qrel. + """ + examples = datasets.load_dataset( + "xlangai/bright", split, + cache_dir=self.cache_dir, + download_mode=self.hf_download_mode + )[dataset_name] + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, f"{split}_qrels.jsonl") + qrels_dict = defaultdict(dict) + with open(save_path, "w", encoding="utf-8") as f: + for data in tqdm(examples, desc="Loading and Saving qrels"): + + # NOTE: we modify the qid here to distinguish the queries from different splits + qid = f'{split}-{data["id"]}' + + for docid in data["gold_ids"]: + _data = { + "qid": qid, + "docid": docid, + "relevance": 1 + } + qrels_dict[qid][docid] = 1 + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + + # NOTE: we record the excluded_ids in qrels with relevance 0 to remove corresponding documents from raw search results. Refer to `searcher.py` for details. + for ex_docid in list(set(data["excluded_ids"])): + if ex_docid == "N/A": + continue + assert ex_docid not in qrels_dict[qid], f"{ex_docid} in {qid}" + _data = { + "qid": qid, + "docid": ex_docid, + "relevance": 0 + } + qrels_dict[qid][ex_docid] = 0 + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + else: + qrels_dict = defaultdict(dict) + for data in tqdm(examples, desc="Loading qrels"): + + # NOTE: we modify the qid here to distinguish the queries from different splits + qid = f'{split}-{data["id"]}' + + for docid in data["gold_ids"]: + qrels_dict[qid][docid] = 1 + + # NOTE: we record the excluded_ids in qrels with relevance 0 to remove corresponding documents from raw search results. Refer to `searcher.py` for details. 
+ for ex_docid in data["excluded_ids"]: + if ex_docid == "N/A": + continue + assert ex_docid not in qrels_dict[qid], f"{ex_docid} in {qid}" + _data = { + "qid": qid, + "docid": ex_docid, + "relevance": 0 + } + qrels_dict[qid][ex_docid] = 0 + return datasets.DatasetDict(qrels_dict) + + def _load_remote_queries( + self, + dataset_name: str, + split: str = 'examples', + save_dir: Optional[str] = None + ) -> datasets.DatasetDict: + """Load the queries from HF. + + Args: + dataset_name (str): Name of the dataset. + split (str, optional): Split of the dataset. Defaults to ``'examples'``. + save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of queries. + """ + examples = datasets.load_dataset( + "xlangai/bright", split, + cache_dir=self.cache_dir, + download_mode=self.hf_download_mode + )[dataset_name] + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, f"{split}_queries.jsonl") + queries_dict = {} + with open(save_path, "w", encoding="utf-8") as f: + for data in tqdm(examples, desc="Loading and Saving queries"): + + # NOTE: we modify the qid here to distinguish the queries from different splits + qid, query = f'{split}-{data["id"]}', data["query"] + + _data = { + "id": qid, + "text": query + } + queries_dict[qid] = query + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + else: + # NOTE: we modify the qid here to distinguish the queries from different splits + queries_dict = {f'{split}-{data["id"]}': data["query"] for data in tqdm(examples, desc="Loading queries")} + return datasets.DatasetDict(queries_dict) + + +class BrightLongEvalDataLoader(AbsEvalDataLoader): + """ + Data loader class for Bright(long). + """ + def available_dataset_names(self) -> List[str]: + """ + Get the available dataset names. + + Returns: + List[str]: All the available dataset names. + """ + return [ + # StackExchange + "biology", "earth_science", "economics", "psychology", "robotics", "stackoverflow", "sustainable_living", + # Coding + "pony", + ] + + def available_splits(self, dataset_name: str) -> List[str]: + """ + Get the avaialble splits. + + Args: + dataset_name (str): Dataset name. + + Returns: + List[str]: All the available splits for the dataset. + """ + return [ + # normal splits + "examples", + # w/ reasoning splits + "Gemini-1.0_reason", "claude-3-opus_reason", "gpt4_reason", "grit_reason", "llama3-70b_reason", + ] + + def _load_remote_corpus( + self, + dataset_name: str, + save_dir: Optional[str] = None + ) -> datasets.DatasetDict: + """Load the corpus dataset from HF. + + Args: + dataset_name (str): Name of the dataset. + save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of corpus. 
+ """ + corpus = datasets.load_dataset( + "xlangai/bright", "long_documents", + cache_dir=self.cache_dir, + download_mode=self.hf_download_mode + )[dataset_name] + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, "corpus.jsonl") + corpus_dict = {} + with open(save_path, "w", encoding="utf-8") as f: + for data in tqdm(corpus, desc="Loading and Saving corpus"): + docid, text = str(data["id"]), data["content"] + _data = { + "id": docid, + "text": text + } + corpus_dict[docid] = {"text": text} + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}") + else: + corpus_dict = {str(data["id"]): {"text": data["content"]} for data in tqdm(corpus, desc="Loading corpus")} + return datasets.DatasetDict(corpus_dict) + + def _load_remote_qrels( + self, + dataset_name: str, + split: str = 'examples', + save_dir: Optional[str] = None + ) -> datasets.DatasetDict: + """Load the qrels from HF. + + Args: + dataset_name (str): Name of the dataset. + split (str, optional): Split of the dataset. Defaults to ``'examples'``. + save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of qrel. + """ + examples = datasets.load_dataset( + "xlangai/bright", split, + cache_dir=self.cache_dir, + download_mode=self.hf_download_mode + )[dataset_name] + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, f"{split}_qrels.jsonl") + qrels_dict = defaultdict(dict) + with open(save_path, "w", encoding="utf-8") as f: + for data in tqdm(examples, desc="Loading and Saving qrels"): + + # NOTE: we modify the qid here to distinguish the queries from different splits + qid = f'{split}-{data["id"]}' + + for docid in data["gold_ids_long"]: + _data = { + "qid": qid, + "docid": docid, + "relevance": 1 + } + qrels_dict[qid][docid] = 1 + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + + # NOTE: we record the excluded_ids in qrels with relevance 0 to remove corresponding documents from raw search results. Refer to `searcher.py` for details. + for ex_docid in list(set(data["excluded_ids"])): + if ex_docid == "N/A": + continue + assert ex_docid not in qrels_dict[qid], f"{ex_docid} in {qid}" + _data = { + "qid": qid, + "docid": ex_docid, + "relevance": 0 + } + qrels_dict[qid][ex_docid] = 0 + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + else: + qrels_dict = defaultdict(dict) + for data in tqdm(examples, desc="Loading qrels"): + + # NOTE: we modify the qid here to distinguish the queries from different splits + qid = f'{split}-{data["id"]}' + + for docid in data["gold_ids_long"]: + qrels_dict[qid][docid] = 1 + + # NOTE: we record the excluded_ids in qrels with relevance 0 to remove corresponding documents from raw search results. Refer to `searcher.py` for details. + for ex_docid in data["excluded_ids"]: + if ex_docid == "N/A": + continue + assert ex_docid not in qrels_dict[qid], f"{ex_docid} in {qid}" + _data = { + "qid": qid, + "docid": ex_docid, + "relevance": 0 + } + qrels_dict[qid][ex_docid] = 0 + return datasets.DatasetDict(qrels_dict) + + def _load_remote_queries( + self, + dataset_name: str, + split: str = 'examples', + save_dir: Optional[str] = None + ) -> datasets.DatasetDict: + """Load the queries from HF. + + Args: + dataset_name (str): Name of the dataset. + split (str, optional): Split of the dataset. Defaults to ``'examples'``. 
+ save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of queries. + """ + examples = datasets.load_dataset( + "xlangai/bright", split, + cache_dir=self.cache_dir, + download_mode=self.hf_download_mode + )[dataset_name] + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, f"{split}_queries.jsonl") + queries_dict = {} + with open(save_path, "w", encoding="utf-8") as f: + for data in tqdm(examples, desc="Loading and Saving queries"): + + # NOTE: we modify the qid here to distinguish the queries from different splits + qid, query = f'{split}-{data["id"]}', data["query"] + + _data = { + "id": qid, + "text": query + } + queries_dict[qid] = query + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + else: + # NOTE: we modify the qid here to distinguish the queries from different splits + queries_dict = {f'{split}-{data["id"]}': data["query"] for data in tqdm(examples, desc="Loading queries")} + return datasets.DatasetDict(queries_dict) diff --git a/FlagEmbedding/evaluation/bright/prompts.py b/FlagEmbedding/evaluation/bright/prompts.py new file mode 100644 index 00000000..5a08fae9 --- /dev/null +++ b/FlagEmbedding/evaluation/bright/prompts.py @@ -0,0 +1,31 @@ +BrightShortInstructions = { + # StackExchange + "biology": "Given a Biology post, retrieve relevant passages that help answer the post.", + "earth_science": "Given an Earth Science post, retrieve relevant passages that help answer the post.", + "economics": "Given an Economics post, retrieve relevant passages that help answer the post.", + "psychology": "Given a Psychology post, retrieve relevant passages that help answer the post.", + "robotics": "Given a Robotics post, retrieve relevant passages that help answer the post.", + "stackoverflow": "Given a Stack Overflow post, retrieve relevant passages that help answer the post.", + "sustainable_living": "Given a Sustainable Living post, retrieve relevant passages that help answer the post.", + # Coding + "leetcode": "Given a Coding problem, retrieve relevant examples that help answer the problem.", + "pony": "Given a Pony question, retrieve relevant passages that help answer the question.", + # Theorem-based + "aops": "Given a Math problem, retrieve relevant examples that help answer the problem.", + "theoremqa_questions": "Given a Math problem, retrieve relevant examples that help answer the problem.", + "theoremqa_theorems": "Given a Math problem, retrieve relevant theorems that help answer the problem.", +} + + +BrightLongInstructions = { + # StackExchange + "biology": "Given a Biology post, retrieve relevant documents that help answer the post.", + "earth_science": "Given an Earth Science post, retrieve relevant documents that help answer the post.", + "economics": "Given an Economics post, retrieve relevant documents that help answer the post.", + "psychology": "Given a Psychology post, retrieve relevant documents that help answer the post.", + "robotics": "Given a Robotics post, retrieve relevant documents that help answer the post.", + "stackoverflow": "Given a Stack Overflow post, retrieve relevant documents that help answer the post.", + "sustainable_living": "Given a Sustainable Living post, retrieve relevant documents that help answer the post.", + # Coding + "pony": "Given a Pony question, retrieve relevant documents that help answer the question", +} diff --git a/FlagEmbedding/evaluation/bright/runner.py 
b/FlagEmbedding/evaluation/bright/runner.py new file mode 100644 index 00000000..d7441443 --- /dev/null +++ b/FlagEmbedding/evaluation/bright/runner.py @@ -0,0 +1,119 @@ +import logging +from typing import Union, Tuple +from FlagEmbedding.abc.evaluation import AbsEvalRunner, EvalReranker, \ + AbsEvalModelArgs as BrightEvalModelArgs + +from .prompts import BrightShortInstructions, BrightLongInstructions +from .arguments import BrightEvalArgs +from .data_loader import BrightShortEvalDataLoader, BrightLongEvalDataLoader +from .searcher import BrightEvalDenseRetriever + +logger = logging.getLogger(__name__) + + +class BrightEvalRunner(AbsEvalRunner): + """ + Evaluation runner of Bright. + """ + def __init__(self, eval_args: BrightEvalArgs, model_args: BrightEvalModelArgs): + super().__init__(eval_args, model_args) + self.eval_args: BrightEvalArgs + self.model_args: BrightEvalModelArgs + + def load_data_loader(self) -> Union[BrightShortEvalDataLoader, BrightLongEvalDataLoader]: + """Load the data loader instance by args. + + Returns: + Union[BrightShortEvalDataLoader, BrightLongEvalDataLoader]: The Bright data loader instance. + """ + if self.eval_args.task_type == "short": + data_loader_class = BrightShortEvalDataLoader + elif self.eval_args.task_type == "long": + data_loader_class = BrightLongEvalDataLoader + else: + raise ValueError(f"Invalid task type: {self.eval_args.task_type}") + + data_loader = data_loader_class( + eval_name=self.eval_args.eval_name, + dataset_dir=self.eval_args.dataset_dir, + cache_dir=self.eval_args.cache_path, + token=self.eval_args.token, + force_redownload=self.eval_args.force_redownload, + ) + return data_loader + + def load_retriever_and_reranker(self) -> Tuple[BrightEvalDenseRetriever, Union[EvalReranker, None]]: + """Load retriever and reranker for evaluation + + Returns: + Tuple[BrightEvalDenseRetriever, Union[EvalReranker, None]]: A :class:BrightEvalDenseRetriever object for retrieval, and a + :class:EvalReranker object if reranker provided. + """ + embedder, reranker = self.get_models(self.model_args) + retriever = BrightEvalDenseRetriever( + embedder, + search_top_k=self.eval_args.search_top_k, + overwrite=self.eval_args.overwrite + ) + if reranker is not None: + reranker = EvalReranker(reranker, rerank_top_k=self.eval_args.rerank_top_k) + return retriever, reranker + + def run(self): + """ + Run the whole evaluation. 
+ """ + if self.eval_args.dataset_names is None: + dataset_names = self.data_loader.available_dataset_names() + else: + dataset_names = self.data_loader.check_dataset_names(self.eval_args.dataset_names) + + if len(dataset_names) == 0: + logger.info(f"Running {self.eval_args.eval_name} evaluation on the default dataset.") + self.evaluator( + splits=self.eval_args.splits, + search_results_save_dir=self.eval_args.output_dir, + retriever=self.retriever, + reranker=self.reranker, + corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir, + ignore_identical_ids=self.eval_args.ignore_identical_ids, + k_values=self.eval_args.k_values + ) + logger.info(f"{self.eval_args.eval_name} evaluation completed.") + else: + logger.info(f"Running {self.eval_args.eval_name} evaluation on the following dataset names: {dataset_names}") + for dataset_name in dataset_names: + if self.eval_args.use_special_instructions: + self.retriever.stop_multi_process_pool() + if self.eval_args.task_type == "short": + self.retriever.embedder.query_instruction_for_retrieval = BrightShortInstructions[dataset_name] + elif self.eval_args.task_type == "long": + self.retriever.embedder.query_instruction_for_retrieval = BrightLongInstructions[dataset_name] + else: + raise ValueError(f"Invalid task type: {self.eval_args.task_type}") + + # NOTE: pass qrels to searcher to exclude documents from raw search results + evaluator_kwargs = {} + evaluator_kwargs["retriever_qrels"] = self.data_loader.load_qrels(dataset_name=dataset_name, split=self.eval_args.splits) + + logger.info(f"Running {self.eval_args.eval_name} evaluation on: {dataset_name}") + self.evaluator( + splits=self.eval_args.splits, + search_results_save_dir=self.eval_args.output_dir, + retriever=self.retriever, + reranker=self.reranker, + corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir, + ignore_identical_ids=self.eval_args.ignore_identical_ids, + k_values=self.eval_args.k_values, + dataset_name=dataset_name, + **evaluator_kwargs, + ) + logger.info(f"{self.eval_args.eval_name} evaluation on {dataset_names} completed.") + + logger.info("Start computing metrics.") + self.evaluate_metrics( + search_results_save_dir=self.eval_args.output_dir, + output_method=self.eval_args.eval_output_method, + output_path=self.eval_args.eval_output_path, + metrics=self.eval_args.eval_metrics + ) diff --git a/FlagEmbedding/evaluation/bright/searcher.py b/FlagEmbedding/evaluation/bright/searcher.py new file mode 100644 index 00000000..90856e89 --- /dev/null +++ b/FlagEmbedding/evaluation/bright/searcher.py @@ -0,0 +1,127 @@ +import os +import logging +import gc +import torch +import numpy as np +from typing import Any, Dict, Optional + +from FlagEmbedding.abc.evaluation.utils import index, search + +from FlagEmbedding.abc.evaluation import EvalRetriever + +logger = logging.getLogger(__name__) + + +class BrightEvalDenseRetriever(EvalRetriever): + """ + Child class of :class:EvalRetriever for dense retrieval. + """ + def __call__( + self, + corpus: Dict[str, Dict[str, Any]], + queries: Dict[str, str], + corpus_embd_save_dir: Optional[str] = None, + ignore_identical_ids: bool = False, + **kwargs, + ) -> Dict[str, Dict[str, float]]: + """ + This is called during the retrieval process. + + Parameters: + corpus: Dict[str, Dict[str, Any]]: Corpus of documents. + Structure: {: {"text": }}. + Example: {"doc-0": {"text": "This is a document."}} + queries: Dict[str, str]: Queries to search for. + Structure: {: }. 
+ Example: {"q-0": "This is a query."} + corpus_embd_save_dir (Optional[str]): Defaults to :data:`None`. + ignore_identical_ids (bool): Defaults to :data:`False`. + **kwargs: Any: Additional arguments. + + Returns: Dict[str, Dict[str, float]]: Top-k search results for each query. k is specified by search_top_k. + Structure: {qid: {docid: score}}. The higher is the score, the more relevant is the document. + Example: {"q-0": {"doc-0": 0.9}} + """ + if ignore_identical_ids: + logger.warning("ignore_identical_ids is set to True. This means that the search results will not contain identical ids. Note: Dataset such as MIRACL should NOT set this to True.") + + # dense embedding models do not require language as input: AIRBench evaluation + kwargs.pop("language", None) + + corpus_ids = [] + corpus_texts = [] + for docid, doc in corpus.items(): + corpus_ids.append(docid) + corpus_texts.append( + doc["text"] if "title" not in doc + else f"{doc['title']} {doc['text']}".strip() + ) + queries_ids = [] + queries_texts = [] + for qid, query in queries.items(): + queries_ids.append(qid) + queries_texts.append(query) + + # NOTE: obtain excluded ids from qrels to remove corresponding documents from raw search results + excluded_ids = {} + qrels = kwargs.pop("retriever_qrels", None) + if qrels is not None: + for qid in qrels: + excluded_ids[qid] = [] + for docid, score in qrels[qid].items(): + if score != 1: + excluded_ids[qid].append(docid) + else: + logger.warning("No qrels provided, so no documents will be excluded.") + + if corpus_embd_save_dir is not None: + if os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")) and not self.overwrite: + corpus_emb = np.load(os.path.join(corpus_embd_save_dir, "doc.npy")) + else: + corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs) + else: + corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs) + + queries_emb = self.embedder.encode_queries(queries_texts, **kwargs) + + # check if the embeddings are in dictionary format: M3Embedder + if isinstance(corpus_emb, dict): + corpus_emb = corpus_emb["dense_vecs"] + if isinstance(queries_emb, dict): + queries_emb = queries_emb["dense_vecs"] + + if corpus_embd_save_dir is not None and \ + (not os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")) or self.overwrite): + os.makedirs(corpus_embd_save_dir, exist_ok=True) + np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb) + + gc.collect() + torch.cuda.empty_cache() + + faiss_index = index(corpus_embeddings=corpus_emb) + all_scores, all_indices = search(query_embeddings=queries_emb, faiss_index=faiss_index, k=self.search_top_k) + + results = {} + for idx, (scores, indices) in enumerate(zip(all_scores, all_indices)): + query_id = queries_ids[idx] + + results[query_id] = {} + for score, indice in zip(scores, indices): + if indice != -1: + if ignore_identical_ids and corpus_ids[indice] == query_id: + continue + results[query_id][corpus_ids[indice]] = float(score) + + if qrels is not None: + # NOTE: Filter out documents with ids in excluded_ids + for docid in set(excluded_ids[query_id]): + if docid != "N/A": + results[query_id].pop(docid, None) + + sorted_scores = sorted(results[query_id].items(), key=lambda item: item[1], reverse=True) + # Store the top-k results for the current query + results[query_id] = {} + for docid, score in sorted_scores[:self.search_top_k]: + results[query_id][docid] = float(score) + + return results diff --git a/examples/README.md b/examples/README.md index 1b4de2a3..d115b8fc 100644 --- 
a/examples/README.md +++ b/examples/README.md @@ -154,7 +154,7 @@ torchrun --nproc_per_node 2 \ ## 5. Evaluation -We support evaluations on [MTEB](https://github.com/embeddings-benchmark/mteb), [BEIR](https://github.com/beir-cellar/beir), [MSMARCO](https://microsoft.github.io/msmarco/), [MIRACL](https://github.com/project-miracl/miracl), [MLDR](https://huggingface.co/datasets/Shitao/MLDR), [MKQA](https://github.com/apple/ml-mkqa), [AIR-Bench](https://github.com/AIR-Bench/AIR-Bench), and custom datasets. Below is an example of evaluating MSMARCO passages. For more details, please refer to the [evaluation examples](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/evaluation). +We support evaluations on [MTEB](https://github.com/embeddings-benchmark/mteb), [BEIR](https://github.com/beir-cellar/beir), [MSMARCO](https://microsoft.github.io/msmarco/), [MIRACL](https://github.com/project-miracl/miracl), [MLDR](https://huggingface.co/datasets/Shitao/MLDR), [MKQA](https://github.com/apple/ml-mkqa), [AIR-Bench](https://github.com/AIR-Bench/AIR-Bench), [BRIGHT](https://brightbenchmark.github.io/), and custom datasets. Below is an example of evaluating MSMARCO passages. For more details, please refer to the [evaluation examples](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/evaluation). ```shell pip install pytrec_eval diff --git a/examples/evaluation/README.md b/examples/evaluation/README.md index 597b1342..ea88fc6c 100644 --- a/examples/evaluation/README.md +++ b/examples/evaluation/README.md @@ -19,8 +19,9 @@ In this section, we will first introduce the commonly used arguments across all - [(4) MIRACL](#4-MIRACL) - [(5) MLDR](#5-MLDR) - [(6) MKQA](#6-MKQA) - - [(7) AIR-Bench](#7-Air-Bench) - - [(8) Custom Dataset](#8-Custom-Dataset) + - [(7) AIR-Bench](#7-AIR-Bench) + - [(8) BRIGHT](#8-BRIGHT) + - [(9) Custom Dataset](#9-Custom-Dataset) ## Introduction @@ -320,7 +321,44 @@ python -m FlagEmbedding.evaluation.air_bench \ --reranker_max_length 1024 ``` -### 8. Custom Dataset +### 8. BRIGHT + +[BRIGHT](https://brightbenchmark.github.io/) supports evaluations on reasoning-intensive text retrieval tasks, and includes 12 datasets: `biology`, `earth_science`, `economics`, `psychology`, `robotics`, `stackoverflow`, `sustainable_living`, `leetcode`, `pony`, `aops`, `theoremqa_questions`, `theoremqa_theorems`. 
+ +Here is an example of running the evaluation: + +```shell +# Available splits (provided by BRIGHT): examples, gpt4_reason, claude-3-opus_reason, Gemini-1.0_reason, llama3-70b_reason, grit_reason +# NOTE: run a single split at a time to ensure the output evaluation results are formatted correctly +split="examples" +python -m FlagEmbedding.evaluation.bright \ + --task_type short \ + --use_special_instructions True \ + --eval_name bright_short \ + --dataset_dir ./bright_short/data \ + --dataset_names pony theoremqa_theorems \ + --splits $split \ + --corpus_embd_save_dir ./bright_short/corpus_embd \ + --output_dir ./bright_short/search_results/$split \ + --search_top_k 2000 \ + --cache_path ./cache/data \ + --overwrite False \ + --k_values 1 10 100 \ + --eval_output_method markdown \ + --eval_output_path ./bright_short/eval_results_$split.md \ + --eval_metrics ndcg_at_10 recall_at_10 recall_at_100 \ + --embedder_name_or_path BAAI/bge-reasoner-embed-qwen3-8b-0923 \ + --embedder_model_class decoder-only-base \ + --query_instruction_format_for_retrieval 'Instruct: {}\nQuery: {}' \ + --pooling_method last_token \ + --devices cuda:0 cuda:1 \ + --cache_dir ./cache/model \ + --embedder_batch_size 2 \ + --embedder_query_max_length 8192 \ + --embedder_passage_max_length 8192 +``` + +### 9. Custom Dataset The example data for `corpus.jsonl`: diff --git a/examples/evaluation/bright/eval_bright_short.sh b/examples/evaluation/bright/eval_bright_short.sh new file mode 100644 index 00000000..190ca79e --- /dev/null +++ b/examples/evaluation/bright/eval_bright_short.sh @@ -0,0 +1,52 @@ +if [ -z "$HF_HUB_CACHE" ]; then + export HF_HUB_CACHE="$HOME/.cache/huggingface/hub" +fi + +# full datasets +# dataset_names="biology earth_science economics psychology robotics stackoverflow sustainable_living leetcode pony aops theoremqa_questions theoremqa_theorems" + +# small datasets for quick test +dataset_names="pony theoremqa_theorems" + +model_args="\ + --embedder_name_or_path BAAI/bge-reasoner-embed-qwen3-8b-0923 \ + --embedder_model_class decoder-only-base \ + --query_instruction_format_for_retrieval 'Instruct: {}\nQuery: {}' \ + --pooling_method last_token \ + --devices cuda:0 cuda:1 \ + --cache_dir $HF_HUB_CACHE \ + --embedder_batch_size 2 \ + --embedder_query_max_length 8192 \ + --embedder_passage_max_length 8192 \ +" + +split_list=("examples" "gpt4_reason") + +for split in "${split_list[@]}"; do + eval_args="\ + --task_type short \ + --use_special_instructions True \ + --eval_name bright_short \ + --dataset_dir ./bright_short/data \ + --dataset_names $dataset_names \ + --splits $split \ + --corpus_embd_save_dir ./bright_short/corpus_embd \ + --output_dir ./bright_short/search_results/$split \ + --search_top_k 2000 \ + --cache_path $HF_HUB_CACHE \ + --overwrite False \ + --k_values 1 10 100 \ + --eval_output_method markdown \ + --eval_output_path ./bright_short/eval_results_$split.md \ + --eval_metrics ndcg_at_10 recall_at_10 recall_at_100 \ + " + + cmd="python -m FlagEmbedding.evaluation.bright \ + $eval_args \ + $model_args \ + " + + echo $cmd + eval $cmd + +done
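This PR ships only `eval_bright_short.sh`. For the long-document setting (`--task_type long`), which covers the 8 datasets exposed by `BrightLongEvalDataLoader` (`biology`, `earth_science`, `economics`, `psychology`, `robotics`, `stackoverflow`, `sustainable_living`, `pony`), a companion script could look like the sketch below. It is not part of this PR: it simply mirrors `eval_bright_short.sh` with the long-task arguments, and the `./bright_long/...` paths are illustrative.

```shell
if [ -z "$HF_HUB_CACHE" ]; then
    export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
fi

# the 8 datasets available in the long-document setting (see BrightLongEvalDataLoader)
dataset_names="biology earth_science economics psychology robotics stackoverflow sustainable_living pony"

model_args="\
    --embedder_name_or_path BAAI/bge-reasoner-embed-qwen3-8b-0923 \
    --embedder_model_class decoder-only-base \
    --query_instruction_format_for_retrieval 'Instruct: {}\nQuery: {}' \
    --pooling_method last_token \
    --devices cuda:0 cuda:1 \
    --cache_dir $HF_HUB_CACHE \
    --embedder_batch_size 2 \
    --embedder_query_max_length 8192 \
    --embedder_passage_max_length 8192 \
"

split_list=("examples" "gpt4_reason")

for split in "${split_list[@]}"; do
    eval_args="\
        --task_type long \
        --use_special_instructions True \
        --eval_name bright_long \
        --dataset_dir ./bright_long/data \
        --dataset_names $dataset_names \
        --splits $split \
        --corpus_embd_save_dir ./bright_long/corpus_embd \
        --output_dir ./bright_long/search_results/$split \
        --search_top_k 2000 \
        --cache_path $HF_HUB_CACHE \
        --overwrite False \
        --k_values 1 10 100 \
        --eval_output_method markdown \
        --eval_output_path ./bright_long/eval_results_$split.md \
        --eval_metrics ndcg_at_10 recall_at_10 recall_at_100 \
    "

    cmd="python -m FlagEmbedding.evaluation.bright \
        $eval_args \
        $model_args \
    "

    echo $cmd
    eval $cmd

done
```

Like `eval_bright_short.sh`, this evaluates one split per pass so the per-split result files stay separate.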