Skip to content

Commit 5b0b338

Browse files
authored
fix: Ensure eval mode for TableReader model for predictions (#3743)
* Add model.eval() calls to the prediction functions in the table reader * Add a unit test to check that inference-time prediction still works when the model has been set to train mode.
1 parent 659020f commit 5b0b338

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

haystack/nodes/reader/table.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def _predict_tapas(self, inputs: BatchEncoding, document: Document) -> Answer:
251251
string_table = orig_table.astype(str)
252252

253253
# Forward query and table through model and convert logits to predictions
254+
self.model.eval()
254255
with torch.inference_mode():
255256
outputs = self.model(**inputs)
256257

@@ -424,6 +425,7 @@ def _predict_tapas_scored(self, inputs: BatchEncoding, document: Document) -> Tu
424425
string_table = orig_table.astype(str)
425426

426427
# Forward pass through model
428+
self.model.eval()
427429
with torch.inference_mode():
428430
outputs = self.model.tapas(**inputs)
429431
table_score = self.model.classifier(outputs.pooler_output)
@@ -719,6 +721,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
719721
padding=True,
720722
)
721723
row_inputs.to(self.devices[0])
724+
self.row_model.eval()
722725
with torch.inference_mode():
723726
row_outputs = self.row_model(**row_inputs)
724727
row_logits = row_outputs[0].detach().cpu().numpy()[:, 1]
@@ -733,6 +736,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
733736
padding=True,
734737
)
735738
column_inputs.to(self.devices[0])
739+
self.column_model.eval()
736740
with torch.inference_mode():
737741
column_outputs = self.column_model(**column_inputs)
738742
column_logits = column_outputs[0].detach().cpu().numpy()[:, 1]

test/nodes/test_table_reader.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
import pandas as pd
4+
import torch
45
import pytest
56

67
from haystack.schema import Document, Answer
@@ -41,6 +42,7 @@ def table3():
4142
@pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True)
4243
def test_table_reader(table_reader_and_param, table1, table2):
4344
table_reader, param = table_reader_and_param
45+
4446
query = "When was Di Caprio born?"
4547
prediction = table_reader.predict(
4648
query=query,
@@ -72,6 +74,42 @@ def test_table_reader(table_reader_and_param, table1, table2):
7274
assert prediction["answers"][1].offsets_in_context[0].end == reference2[param]["end"]
7375

7476

77+
@pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True)
def test_table_reader_train_mode(table_reader_and_param, table1, table2):
    """Check that predictions are not affected when the models were left in train mode.

    The underlying model(s) of the reader are explicitly switched to train mode
    before calling ``predict``; the second returned answer must still match the
    reference values used for a reader in its default state.
    """
    table_reader, param = table_reader_and_param

    # fork_rng() snapshots the global RNG state and restores it on exit, even when
    # the body raises, so the fixed seed below cannot leak into other tests.
    # NOTE: torch.seed() must not be used to read the current seed: it re-seeds the
    # generator with a random value and returns that new value.
    with torch.random.fork_rng():
        # Set a deterministic seed
        torch.manual_seed(0)

        # Ensure that if the model is put in train mode, predictions are not affected
        if param == "rci":
            table_reader.row_model.train()
            table_reader.column_model.train()
        else:
            table_reader.table_encoder.model.train()

        query = "When was Di Caprio born?"
        prediction = table_reader.predict(
            query=query,
            documents=[Document(content=table1, content_type="table"), Document(content=table2, content_type="table")],
        )

    # Check the second answer in the list
    reference2 = {
        "tapas_small": {"answer": "5 april 1980", "start": 7, "end": 8, "score": 0.86314},
        "rci": {"answer": "47", "start": 5, "end": 6, "score": -6.836},
        "tapas_scored": {"answer": "brad pitt", "start": 0, "end": 1, "score": 0.49078},
    }
    assert prediction["answers"][1].score == pytest.approx(reference2[param]["score"], rel=1e-3)
    assert prediction["answers"][1].answer == reference2[param]["answer"]
    assert prediction["answers"][1].offsets_in_context[0].start == reference2[param]["start"]
    assert prediction["answers"][1].offsets_in_context[0].end == reference2[param]["end"]
111+
112+
75113
@pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True)
76114
def test_table_reader_batch_single_query_single_doc_list(table_reader_and_param, table1, table2):
77115
table_reader, param = table_reader_and_param

0 commit comments

Comments
 (0)