|
1 | 1 | import logging |
2 | 2 |
|
3 | 3 | import pandas as pd |
| 4 | +import torch |
4 | 5 | import pytest |
5 | 6 |
|
6 | 7 | from haystack.schema import Document, Answer |
@@ -41,6 +42,7 @@ def table3(): |
41 | 42 | @pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True) |
42 | 43 | def test_table_reader(table_reader_and_param, table1, table2): |
43 | 44 | table_reader, param = table_reader_and_param |
| 45 | + |
44 | 46 | query = "When was Di Caprio born?" |
45 | 47 | prediction = table_reader.predict( |
46 | 48 | query=query, |
@@ -72,6 +74,42 @@ def test_table_reader(table_reader_and_param, table1, table2): |
72 | 74 | assert prediction["answers"][1].offsets_in_context[0].end == reference2[param]["end"] |
73 | 75 |
|
74 | 76 |
|
| 77 | +@pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True) |
| 78 | +def test_table_reader_train_mode(table_reader_and_param, table1, table2): |
| 79 | + table_reader, param = table_reader_and_param |
| 80 | + |
| 81 | + # Set to deterministic seed |
| 82 | + old_seed = torch.seed() |
| 83 | + torch.manual_seed(0) |
| 84 | + |
| 85 | + # Ensure that if model is put in train mode that predictions are not effected |
| 86 | + if param != "rci": |
| 87 | + table_reader.table_encoder.model.train() |
| 88 | + elif param == "rci": |
| 89 | + table_reader.row_model.train() |
| 90 | + table_reader.column_model.train() |
| 91 | + |
| 92 | + query = "When was Di Caprio born?" |
| 93 | + prediction = table_reader.predict( |
| 94 | + query=query, |
| 95 | + documents=[Document(content=table1, content_type="table"), Document(content=table2, content_type="table")], |
| 96 | + ) |
| 97 | + |
| 98 | + # Check the second answer in the list |
| 99 | + reference2 = { |
| 100 | + "tapas_small": {"answer": "5 april 1980", "start": 7, "end": 8, "score": 0.86314}, |
| 101 | + "rci": {"answer": "47", "start": 5, "end": 6, "score": -6.836}, |
| 102 | + "tapas_scored": {"answer": "brad pitt", "start": 0, "end": 1, "score": 0.49078}, |
| 103 | + } |
| 104 | + assert prediction["answers"][1].score == pytest.approx(reference2[param]["score"], rel=1e-3) |
| 105 | + assert prediction["answers"][1].answer == reference2[param]["answer"] |
| 106 | + assert prediction["answers"][1].offsets_in_context[0].start == reference2[param]["start"] |
| 107 | + assert prediction["answers"][1].offsets_in_context[0].end == reference2[param]["end"] |
| 108 | + |
| 109 | + # Set back to old_seed |
| 110 | + torch.manual_seed(old_seed) |
| 111 | + |
| 112 | + |
75 | 113 | @pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True) |
76 | 114 | def test_table_reader_batch_single_query_single_doc_list(table_reader_and_param, table1, table2): |
77 | 115 | table_reader, param = table_reader_and_param |
|
0 commit comments