from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

- from haystack import MultiLabel, Label
+ from haystack import MultiLabel, Label, BaseComponent, Document

from farm.evaluation.squad_evaluation import compute_f1 as calculate_f1_str
from farm.evaluation.squad_evaluation import compute_exact as calculate_em_str

logger = logging.getLogger(__name__)


- class EvalDocuments:
+ class EvalDocuments(BaseComponent):
    """
    This is a pipeline node that should be placed after a node that returns a List of Document, e.g., Retriever or
    Ranker, in order to assess its performance. Performance metrics are stored in this class and updated as each
@@ -22,21 +22,22 @@ class EvalDocuments:
    a look at our evaluation tutorial for more info about open vs closed domain eval (
    https://haystack.deepset.ai/tutorials/evaluation).
    """
-     def __init__(self, debug: bool = False, open_domain: bool = True, top_k_eval_documents: int = 10, name="EvalDocuments"):
+
+     outgoing_edges = 1
+
+     def __init__(self, debug: bool = False, open_domain: bool = True, top_k: int = 10):
        """
        :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
                            When False, correct retrieval is evaluated based on document_id.
        :param debug: When True, a record of each sample and its evaluation will be stored in EvalDocuments.log
        :param top_k: calculate eval metrics for top k results, e.g., recall@k
        """
-         self.outgoing_edges = 1
        self.init_counts()
        self.no_answer_warning = False
        self.debug = debug
        self.log: List = []
        self.open_domain = open_domain
-         self.top_k_eval_documents = top_k_eval_documents
-         self.name = name
+         self.top_k = top_k
        self.too_few_docs_warning = False
        self.top_k_used = 0

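Note for callers: this hunk renames the constructor argument top_k_eval_documents to top_k and drops the name argument, since the node's self.name is now expected to be assigned by the pipeline when the component is added. A minimal before/after sketch, assuming the module path stays haystack.eval:

from haystack.eval import EvalDocuments

# Old call (shown for comparison only):
#     eval_retriever = EvalDocuments(top_k_eval_documents=5, name="EvalRetriever")

# New call: top_k replaces top_k_eval_documents; the name is set later by
# Pipeline.add_node(component=eval_retriever, name="EvalDocuments", inputs=[...]).
eval_retriever = EvalDocuments(debug=False, open_domain=True, top_k=5)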
@@ -53,25 +54,25 @@ def init_counts(self):
        self.reciprocal_rank_sum = 0.0
        self.has_answer_reciprocal_rank_sum = 0.0

-     def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
+     def run(self, documents: List[Document], labels: List[Label], top_k: Optional[int] = None):  # type: ignore
        """Run this node on one sample and its labels"""
        self.query_count += 1
-         retriever_labels = get_label(labels, kwargs["node_id"])
-         if not top_k_eval_documents:
-             top_k_eval_documents = self.top_k_eval_documents
+         retriever_labels = get_label(labels, self.name)
+         if not top_k:
+             top_k = self.top_k

        if not self.top_k_used:
-             self.top_k_used = top_k_eval_documents
-         elif self.top_k_used != top_k_eval_documents:
+             self.top_k_used = top_k
+         elif self.top_k_used != top_k:
            logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
-                            f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
+                            f"being run again with top_k={self.top_k}. "
                           f"The evaluation counter is being reset from this point so that the evaluation "
                           f"metrics are interpretable.")
            self.init_counts()

-         if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
-             logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
-                            f"(currently set to {top_k_eval_documents}).")
+         if len(documents) < top_k and not self.too_few_docs_warning:
+             logger.warning(f"EvalDocuments is being provided less candidate documents than top_k "
+                            f"(currently set to {top_k}).")
            self.too_few_docs_warning = True

        # TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
@@ -89,7 +90,7 @@ def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None,
        # If there are answer span annotations in the labels
        else:
            self.has_answer_count += 1
-             retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
+             retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k)
            self.reciprocal_rank_sum += retrieved_reciprocal_rank
            correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
            self.has_answer_correct += int(correct_retrieval)
@@ -101,11 +102,11 @@ def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None,
        self.recall = self.correct_retrieval_count / self.query_count
        self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count

-         self.top_k_used = top_k_eval_documents
+         self.top_k_used = top_k

        if self.debug:
-             self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
-         return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}, "output_1"
+             self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank})
+         return {"correct_retrieval": correct_retrieval}, "output_1"

    def is_correctly_retrieved(self, retriever_labels, predictions):
        return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
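Because run() now returns only {"correct_retrieval": correct_retrieval} instead of echoing documents, labels, and the incoming kwargs, the per-sample details are available only through the debug log. A short sketch of inspecting the node after an evaluation run; the pipeline setup is assumed, and the record keys match the self.log.append(...) call above:

eval_retriever = EvalDocuments(debug=True, top_k=10)
# ... add the node to an eval pipeline and run the evaluation queries ...

eval_retriever.print()  # prints the aggregated retrieval metrics, e.g. mean_reciprocal_rank@k
for record in eval_retriever.log:
    # Each debug record keeps the documents and labels plus the per-sample outcome.
    print(record["correct_retrieval"], record["retrieved_reciprocal_rank"])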
@@ -142,7 +143,7 @@ def print(self):
        print(f"mean_reciprocal_rank@{self.top_k_used}: {self.mean_reciprocal_rank:.4f}")


- class EvalAnswers:
+ class EvalAnswers(BaseComponent):
    """
    This is a pipeline node that should be placed after a Reader in order to assess the performance of the Reader
    individually or to assess the extractive QA performance of the whole pipeline. Performance metrics are stored in
@@ -152,6 +153,8 @@ class EvalAnswers:
    open vs closed domain eval (https://haystack.deepset.ai/tutorials/evaluation).
    """

+     outgoing_edges = 1
+
    def __init__(self,
                 skip_incorrect_retrieval: bool = True,
                 open_domain: bool = True,
@@ -174,7 +177,6 @@ def __init__(self,
                                   - Large model for German only: "deepset/gbert-large-sts"
        :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log
        """
-         self.outgoing_edges = 1
        self.log: List = []
        self.debug = debug
        self.skip_incorrect_retrieval = skip_incorrect_retrieval
@@ -203,14 +205,14 @@ def init_counts(self):
        self.top_1_sas = 0.0
        self.top_k_sas = 0.0

-     def run(self, labels, answers, **kwargs):
+     def run(self, labels: List[Label], answers: List[dict], correct_retrieval: bool):  # type: ignore
        """Run this node on one sample and its labels"""
        self.query_count += 1
        predictions = answers
-         skip = self.skip_incorrect_retrieval and not kwargs.get("correct_retrieval")
+         skip = self.skip_incorrect_retrieval and not correct_retrieval
        if predictions and not skip:
            self.correct_retrieval_count += 1
-             multi_labels = get_label(labels, kwargs["node_id"])
+             multi_labels = get_label(labels, self.name)
            # If this sample is impossible to answer and expects a no_answer response
            if multi_labels.no_answer:
                self.no_answer_count += 1
@@ -254,7 +256,7 @@ def run(self, labels, answers, **kwargs):
                self.top_k_em_count += top_k_em
                self.top_k_f1_sum += top_k_f1
                self.update_has_answer_metrics()
-         return {**kwargs}, "output_1"
+         return {}, "output_1"

    def evaluate_extraction(self, gold_labels, predictions):
        if self.open_domain:
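Taken together, the two nodes are wired in after the Retriever and the Reader, and the pipeline forwards the correct_retrieval flag returned by EvalDocuments into EvalAnswers.run(). A minimal wiring sketch; the retriever and reader instances and the exact Pipeline.run() arguments are assumptions here (see the evaluation tutorial linked in the docstrings for a complete example):

from haystack import Pipeline
from haystack.eval import EvalAnswers, EvalDocuments

eval_retriever = EvalDocuments(top_k=10)                  # was top_k_eval_documents=10
eval_reader = EvalAnswers(skip_incorrect_retrieval=True)

pipeline = Pipeline()
pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])         # retriever: any node returning List[Document]
pipeline.add_node(component=eval_retriever, name="EvalDocuments", inputs=["Retriever"])
pipeline.add_node(component=reader, name="Reader", inputs=["EvalDocuments"])        # reader: any extractive QA reader
pipeline.add_node(component=eval_reader, name="EvalAnswers", inputs=["Reader"])

# For each evaluation sample, pass the query and its labels through the pipeline.
# EvalDocuments returns {"correct_retrieval": ...}, which the pipeline hands on to
# EvalAnswers.run(labels=..., answers=..., correct_retrieval=...).
result = pipeline.run(query=question, labels=labels)      # question/labels come from your eval dataset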