"""Document Reader Experiments"""

import argparse
from haystack import Pipeline
from haystack.nodes import FARMReader
# from haystack.utils import print_answers

from .module import (
    Dataset,
    DocReader,
    Sports,
    SQuadDataset,
    AdversarialQADataset,
    DuoRCDataset,
    QASportsDataset,
)


# Model setup
# DATASET = Dataset.QASports
# DOC_READER = DocReader.BERT
# SPORT = Sports.SKIING
parser = argparse.ArgumentParser(description="Run document reader experiments.")
parser.add_argument(
    "--dataset",
    type=str,
    default="QASports",
    choices=[d.name for d in Dataset],
    help="Dataset to use for the experiment.",
)
parser.add_argument(
    "--model",
    type=str,
    default="BERT",
    choices=[m.name for m in DocReader],
    help="Document reader model to use.",
)
parser.add_argument(
    "--sport",
    type=str,
    default="ALL",
    choices=[s.name for s in Sports],
    help="Sport to filter for the QASports dataset.",
)
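
# Example invocation (the module path is illustrative; adjust it to wherever
# this script lives in the package):
#   python -m package.reader_experiments --dataset QASports --model BERT --sport BASKETBALL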

args = parser.parse_args()

DATASET = Dataset[args.dataset]
DOC_READER = DocReader[args.model].value
SPORT = Sports[args.sport].value
print(f"Dataset: {DATASET} // Sport: {SPORT}")
print(f"Model: {DOC_READER}")


# Select and instantiate the dataset
def dataset_switch(choice):
    """Return the dataset instance for the given Dataset choice."""
    if choice == Dataset.SQuAD:
        return SQuadDataset()
    elif choice == Dataset.AdvQA:
        return AdversarialQADataset()
    elif choice == Dataset.DuoRC:
        return DuoRCDataset()
    elif choice == Dataset.QASports:
        return QASportsDataset(SPORT)
    else:
        # Fail fast instead of returning a string that would break downstream calls
        raise ValueError(f"Invalid dataset: {choice}")


# Get the dataset
dataset = dataset_switch(DATASET)
docs = dataset.get_documents()
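# `docs` is assumed to be a list of haystack Document objects, since it is fed
# straight to the pipeline's `documents=` argument below.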

"""---
## Document Reader

In this experiment, we explored three Transformer-based models for extractive Question Answering using the [FARM framework](https://github.com/deepset-ai/FARM).
* [BERT paper](https://arxiv.org/abs/1810.04805), [implementation](https://huggingface.co/deepset/bert-base-uncased-squad2)
* [RoBERTa paper](https://arxiv.org/abs/1907.11692), [implementation](https://huggingface.co/deepset/roberta-base-squad2)
* [MiniLM paper](https://arxiv.org/abs/2002.10957), [implementation](https://huggingface.co/deepset/minilm-uncased-squad2)
"""

# Get the reader
reader = FARMReader(DOC_READER, use_gpu=True)

# Build the pipeline
pipe = Pipeline()
pipe.add_node(component=reader, name="Reader", inputs=["Query"])
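# Note: this is a reader-only pipeline. There is no retriever node, so candidate
# documents must be passed in explicitly on each run()/eval() call.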

# # Querying documents
# question = "Who did the Raptors face in the first round of the 2015 Playoffs?"
# prediction = pipe.run(
#     query=question, documents=docs[0:10], params={"Reader": {"top_k": 3}}
# )

# # Print answer
# print_answers(prediction)

| 99 | +"""--- |
| 100 | +## Evaluation |
| 101 | +
|
| 102 | +About the metrics, you can read the [evaluation](https://docs.haystack.deepset.ai/docs/evaluation) web page. |
| 103 | +""" |
| 104 | + |
# Evaluation runs over the full validation set; for a quick smoke test, slice
# the labels first (e.g., eval_labels = eval_labels[0:100]).
eval_labels = dataset.get_validation()
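# Pair each MultiLabel with its gold labels' documents so the Reader is
# evaluated directly on the passages that contain the gold answers.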
eval_docs = [
    [label.document for label in multi_label.labels] for multi_label in eval_labels
]

eval_result = pipe.eval(
    labels=eval_labels, documents=eval_docs, params={"Reader": {"top_k": 1}}
)

# Get and print the metrics
metrics = eval_result.calculate_metrics()
print(metrics)
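
# A minimal sketch of pulling out per-node scores (assumes the standard Haystack
# reader metric keys; verify against your installed Haystack version):
# reader_metrics = metrics["Reader"]
# print(f"EM: {reader_metrics['exact_match']:.3f} // F1: {reader_metrics['f1']:.3f}")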