
Commit a2c160e

bug: skip empty documents in reader (#3773)
* skip empty documents
* test eval_batch and account for tables
1 parent 43328d2 commit a2c160e

6 files changed, +142 −18 lines changed

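In short, the reader now filters out empty documents before prediction instead of passing them to the model. A minimal sketch of the intended behaviour (model name and settings are illustrative, taken from the test fixtures in this commit; the run() call mirrors the new tests below):

# Sketch only: exercising the behaviour this commit targets.
from haystack.nodes import FARMReader
from haystack.schema import Document

reader = FARMReader(model_name_or_path="deepset/bert-medium-squad2-distilled", use_gpu=False)

empty_doc = Document(content="", content_type="text")
predictions, _ = reader.run(query="Who lives in Berlin?", documents=[empty_doc])

# Empty documents are filtered out before prediction, so the reader returns
# no answers instead of failing on an empty passage.
print(predictions["answers"])  # -> []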

haystack/modeling/infer.py

Lines changed: 4 additions & 0 deletions
@@ -508,6 +508,10 @@ def inference_from_objects(
         This parameter has no effect; it will be removed as Inferencer multiprocessing
         has been deprecated.
         """
+        # Return no predictions if there are no inputs
+        if not objects:
+            return []
+
         dicts = [o.to_dict() for o in objects]
         # TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
         # then we can and should use inference from (input) objects!

haystack/nodes/reader/base.py

Lines changed: 49 additions & 10 deletions
@@ -40,17 +40,25 @@ def _calc_no_answer(
         # the most significant difference between scores.
         # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
         # No_ans_gap is a list of this most significant difference per document
-        no_ans_gap_array = np.array(no_ans_gaps)
-        max_no_ans_gap = np.max(no_ans_gap_array)
-        # case 1: all passages "no answer" as top score
-        # max_no_ans_gap is negative, so it increases best pos score
-        # case 2: at least one passage predicts an answer (positive no_ans_gap)
-        no_ans_score = best_score_answer - max_no_ans_gap
+
+        # If there is not even one predicted answer, we return a no_answer with score 1.0
+        if best_score_answer == 0 and len(no_ans_gaps) == 0:
+            no_ans_score = 1024.0
+            no_ans_score_scaled = 1.0
+            max_no_ans_gap = 1024.0
+        else:
+            no_ans_gap_array = np.array(no_ans_gaps)
+            max_no_ans_gap = np.max(no_ans_gap_array)
+            # case 1: all passages "no answer" as top score
+            # max_no_ans_gap is negative, so it increases best pos score
+            # case 2: at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
+            no_ans_score_scaled = float(expit(np.asarray(no_ans_score) / 8))

         no_ans_prediction = Answer(
             answer="",
             type="extractive",
-            score=float(expit(np.asarray(no_ans_score) / 8))
+            score=no_ans_score_scaled
             if use_confidence_scores
             else no_ans_score,  # just a pseudo prob for now or old score,
             context=None,
@@ -80,10 +88,27 @@ def add_doc_meta_data_to_answer(documents: List[Document], answer):
     def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False):  # type: ignore
         self.query_count += 1
         predict = self.timing(self.predict, "query_time")
+        # Remove empty text documents before making predictions
+        documents = [d for d in documents if not isinstance(d.content, str) or d.content.strip() != ""]
         if documents:
             results = predict(query=query, documents=documents, top_k=top_k)
         else:
-            results = {"answers": []}
+            if hasattr(self, "return_no_answers") and self.return_no_answers:
+                no_ans_prediction = Answer(
+                    answer="",
+                    type="extractive",
+                    score=1.0
+                    if hasattr(self, "use_confidence_scores") and self.use_confidence_scores
+                    else 1024.0,  # just a pseudo prob for now or old score,
+                    context=None,
+                    offsets_in_context=[Span(start=0, end=0)],
+                    offsets_in_document=[Span(start=0, end=0)],
+                    document_id=None,
+                    meta=None,
+                )
+                results = {"answers": [no_ans_prediction]}
+            else:
+                results = {"answers": []}

         # Add corresponding document_name and more meta data, if an answer contains the document_id
         results["answers"] = [
@@ -92,7 +117,9 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None

         # run evaluation with labels as node inputs
         if add_isolated_node_eval and labels is not None:
-            relevant_documents = {label.document.id: label.document for label in labels.labels}.values()
+            relevant_documents = [label.document for label in labels.labels]
+            # Filter out empty documents
+            relevant_documents = [d for d in relevant_documents if d.content.strip() != ""]
             results_label_input = predict(query=query, documents=relevant_documents, top_k=top_k)

         # Add corresponding document_name and more meta data, if an answer contains the document_id
@@ -113,6 +140,14 @@ def run_batch( # type: ignore
         add_isolated_node_eval: bool = False,
     ):
         self.query_count += len(queries)
+
+        # Remove empty documents before making predictions
+        if len(documents) > 0:
+            if isinstance(documents[0], Document):
+                documents = [d for d in documents if not isinstance(d.content, str) or d.content.strip() != ""]  # type: ignore[union-attr, assignment]
+            else:
+                documents = [[d for d in docs_per_query if not isinstance(d.content, str) or d.content.strip() != ""] for docs_per_query in documents]  # type: ignore[union-attr]
+
         if not documents:
             return {"answers": []}, "output_1"

@@ -138,7 +173,11 @@ def run_batch( # type: ignore
         if add_isolated_node_eval and labels is not None:
             relevant_documents = []
             for labelx in labels:
-                relevant_documents.append([label.document for label in labelx.labels])
+                # Filter out empty documents
+                relevant_docs_labelx = [
+                    label.document for label in labelx.labels if label.document.content.strip() != ""
+                ]
+                relevant_documents.append(relevant_docs_labelx)
             results_label_input = predict_batch(queries=queries, documents=relevant_documents, top_k=top_k)

         # Add corresponding document_name and more meta data, if an answer contains the document_id
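The 1024.0 fallback above ties into the confidence scaling already used by _calc_no_answer: raw no-answer scores are squashed with expit(score / 8), so the sentinel value maps to a confidence of 1.0. A small sketch of that scaling (not part of the commit; the helper name is illustrative):

# Sketch: the logistic squashing used for no-answer confidences in _calc_no_answer.
import numpy as np
from scipy.special import expit

def scale_no_ans_score(no_ans_score: float) -> float:
    # Same expression as in the diff: expit(score / 8) maps raw scores into (0, 1).
    return float(expit(np.asarray(no_ans_score) / 8))

print(scale_no_ans_score(0.0))     # 0.5  -> an undecided raw score
print(scale_no_ans_score(1024.0))  # ~1.0 -> the fallback used when there are no predictions at all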

haystack/nodes/reader/farm.py

Lines changed: 0 additions & 1 deletion
@@ -832,7 +832,6 @@ def predict_batch(
         # Group predictions together
         grouped_predictions = []
         left_idx = 0
-        right_idx = 0
         for number in number_of_docs:
             right_idx = left_idx + number
             grouped_predictions.append(predictions[left_idx:right_idx])

haystack/nodes/reader/transformers.py

Lines changed: 9 additions & 5 deletions
@@ -233,7 +233,6 @@ def predict_batch(
         grouped_predictions = []
         grouped_inputs = []
         left_idx = 0
-        right_idx = 0
         for number in number_of_docs:
             right_idx = left_idx + number
             grouped_predictions.append(predictions[left_idx:right_idx])
@@ -247,7 +246,9 @@ def predict_batch(
                 for pred in preds_for_single_doc:
                     cur_doc_id = inp.doc_id
                     pred["doc_id"] = cur_doc_id
-            if isinstance(grouped_pred[0], list):
+            if len(grouped_pred) == 0:
+                group = []
+            elif isinstance(grouped_pred[0], list):
                 group = list(itertools.chain.from_iterable(grouped_pred))
             answers, max_no_ans_gap = self._extract_answers_of_predictions(group, all_docs, top_k)
             results["answers"].append(answers)
@@ -271,8 +272,9 @@ def _extract_answers_of_predictions(
         no_ans_gaps = []
         best_overall_score = 0

-        cur_doc_id = predictions[0]["doc_id"]
-        cur_doc = docs[cur_doc_id]
+        if len(predictions) > 0:
+            cur_doc_id = predictions[0]["doc_id"]
+            cur_doc = docs[cur_doc_id]
         no_ans_doc_score = 0
         best_doc_score = 0

@@ -313,7 +315,9 @@ def _extract_answers_of_predictions(
         # + add no_ans_gap for last Document
         if best_doc_score > best_overall_score:
             best_overall_score = best_doc_score
-        no_ans_gaps.append(no_ans_doc_score - best_doc_score)
+
+        if len(predictions) > 0:
+            no_ans_gaps.append(no_ans_doc_score - best_doc_score)

         # Calculate the score for predicting "no answer", relative to our best positive answer score
         no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_overall_score)
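The len(predictions) > 0 and len(grouped_pred) == 0 guards exist because predictions are grouped per query by document count, and a query whose documents were all empty now yields an empty group. A standalone sketch of that grouping (variable names mirror the diff; the data and the left_idx update are assumed for illustration):

# Sketch: grouping flat predictions back per query, as in predict_batch.
predictions = ["p1", "p2", "p3"]   # flat predictions across all queries (made-up values)
number_of_docs = [0, 2, 1]         # first query had only empty documents left after filtering

grouped_predictions = []
left_idx = 0
for number in number_of_docs:
    right_idx = left_idx + number
    grouped_predictions.append(predictions[left_idx:right_idx])
    left_idx = right_idx

print(grouped_predictions)  # [[], ['p1', 'p2'], ['p3']] -> the empty group triggers the new guard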

test/nodes/test_reader.py

Lines changed: 55 additions & 2 deletions
@@ -8,7 +8,7 @@
 from huggingface_hub import snapshot_download
 from haystack.modeling.data_handler.inputs import QAInput, Question

-from haystack.schema import Document, Answer
+from haystack.schema import Document, Answer, Label, MultiLabel, Span
 from haystack.nodes.reader.base import BaseReader
 from haystack.nodes import FARMReader, TransformersReader

@@ -32,6 +32,7 @@ def no_answer_reader(request):
             tokenizer="deepset/bert-medium-squad2-distilled",
             use_gpu=-1,
             top_k_per_candidate=5,
+            return_no_answers=True,
         )


@@ -175,7 +176,6 @@ def test_context_window_size(reader, docs, window_size):
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
 @pytest.mark.parametrize("top_k", [2, 5, 10])
 def test_top_k(reader, docs, top_k):
-
     assert isinstance(reader, FARMReader)

     old_top_k_per_candidate = reader.top_k_per_candidate
@@ -352,3 +352,56 @@ def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
     reader = FARMReader(str(Path(tmpdir, "onnx")))
     result = reader.predict(query="Where does Paul live?", documents=[docs[0]])
     assert result["answers"][0].answer == "New York"
+
+
+LABELS = [
+    MultiLabel(
+        labels=[
+            Label(
+                query="Who lives in Berlin?",
+                answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+                document=Document(
+                    id="a0747b83aea0b60c4b114b15476dd32d", content_type="text", content=""  # empty document
+                ),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+            )
+        ]
+    ),
+    MultiLabel(
+        labels=[
+            Label(
+                query="Who lives in Munich?",
+                answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+                document=Document(
+                    id="something_else", content_type="text", content="My name is Carla and I live in Munich"
+                ),
+                is_correct_answer=True,
+                is_correct_document=True,
+                origin="gold-label",
+            )
+        ]
+    ),
+]
+
+
+def test_reader_skips_empty_documents(reader):
+    predictions, _ = reader.run(query=LABELS[0].labels[0].query, documents=[LABELS[0].labels[0].document])
+    assert predictions["answers"] == []  # no answer given for query as document is empty
+    predictions, _ = reader.run_batch(
+        queries=[l.labels[0].query for l in LABELS], documents=[[l.labels[0].document] for l in LABELS]
+    )
+    assert predictions["answers"][0] == []  # no answer given for 1st query as document is empty
+    assert predictions["answers"][1][0].answer == "Carla"  # answer given for 2nd query as usual
+
+
+@pytest.mark.parametrize("no_answer_reader", ["farm", "transformers"], indirect=True)
+def test_no_answer_reader_skips_empty_documents(no_answer_reader):
+    predictions, _ = no_answer_reader.run(query=LABELS[0].labels[0].query, documents=[LABELS[0].labels[0].document])
+    assert predictions["answers"][0].answer == ""  # Return no_answer as document is empty
+    predictions, _ = no_answer_reader.run_batch(
+        queries=[l.labels[0].query for l in LABELS], documents=[[l.labels[0].document] for l in LABELS]
+    )
+    assert predictions["answers"][0][0].answer == ""  # Return no_answer for 1st query as document is empty
+    assert predictions["answers"][1][1].answer == "Carla"  # answer given for 2nd query as usual
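For readers configured with return_no_answers=True, the base class now returns an explicit no_answer prediction when all documents for a query are empty, which is what the second test checks. A rough standalone equivalent (settings taken from the no_answer_reader fixture; exact values are illustrative):

# Sketch: the no_answer path for empty documents.
from haystack.nodes import FARMReader
from haystack.schema import Document

no_answer_reader = FARMReader(
    model_name_or_path="deepset/bert-medium-squad2-distilled",
    top_k_per_candidate=5,
    return_no_answers=True,
    use_gpu=False,
)
predictions, _ = no_answer_reader.run(
    query="Who lives in Berlin?", documents=[Document(content="", content_type="text")]
)
print(predictions["answers"][0].answer)  # -> "" (a no_answer prediction with confidence 1.0)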

test/pipelines/test_eval.py

Lines changed: 25 additions & 0 deletions
@@ -1411,3 +1411,28 @@ def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_do

     assert metrics["QAReader"]["exact_match"] == 1.0
     assert metrics["QAReader"]["f1"] == 1.0
+
+
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+@pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True)
+def test_empty_documents_dont_fail_pipeline(reader, retriever_with_docs):
+    multilabels = EVAL_LABELS[:2]
+    multilabels[0].labels[0].document.content = ""
+    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+    eval_result_integrated: EvaluationResult = pipeline.eval(labels=multilabels, add_isolated_node_eval=False)
+    assert eval_result_integrated["Reader"]["answer"].iloc[0] == "Carla"
+    eval_result_iso: EvaluationResult = pipeline.eval(labels=multilabels, add_isolated_node_eval=True)
+    assert eval_result_iso["Reader"].loc[eval_result_iso["Reader"]["eval_mode"] == "isolated"]["answer"].iloc[0] == ""
+
+    eval_batch_result_integrated: EvaluationResult = pipeline.eval_batch(
+        labels=multilabels, add_isolated_node_eval=False
+    )
+    assert eval_batch_result_integrated["Reader"]["answer"].iloc[0] == "Carla"
+    eval_batch_result_iso: EvaluationResult = pipeline.eval_batch(labels=multilabels, add_isolated_node_eval=True)
+    assert (
+        eval_batch_result_iso["Reader"]
+        .loc[eval_batch_result_iso["Reader"]["eval_mode"] == "isolated"]["answer"]
+        .iloc[0]
+        == ""
+    )
