-from typing import List
+from typing import List, Tuple
+import pytest
 from hypothesis import given, assume
 import hypothesis.strategies as st
-from tibert import score_mention_detection, CoreferenceDocument
+from tibert.bertcoref import Mention
+from tibert import score_mention_detection, CoreferenceDocument, score_lea
 from tests.strategies import coref_docs
+from more_itertools import flatten
 
 
 @given(docs=st.lists(coref_docs(min_size=1, max_size=32), min_size=1, max_size=3))
 def test_mention_score_perfect_when_same_docs(docs: List[CoreferenceDocument]):
     assume(all([len(doc.coref_chains) > 0 for doc in docs]))
     assert score_mention_detection(docs, docs) == (1.0, 1.0, 1.0)
+
+
+@pytest.mark.parametrize(
+    "pred,ref,expected",
+    [
+        ([["A"]], [["A"]], (1.0, 1.0, 1.0)),
+        (
+            [["A", "B"], ["C", "D"], ["F", "G", "H", "I"]],
+            [["A", "B", "C"], ["D", "E", "F", "G"]],
+            (0.333, 0.24, 0.2779),
+        ),
+    ],
+)
+def test_lea_canonical_examples(
+    pred: List[List[str]], ref: List[List[str]], expected: Tuple[float, float, float]
+):
+    pred_doc = CoreferenceDocument(
+        list(flatten(pred)),
+        [[Mention([mention], 0, 0) for mention in chain] for chain in pred],
+    )
+    ref_doc = CoreferenceDocument(
+        list(flatten(ref)),
+        [[Mention([mention], 0, 0) for mention in chain] for chain in ref],
+    )
+
+    precision, recall, f1 = score_lea([pred_doc], [ref_doc])
+    assert precision == pytest.approx(expected[0], rel=1e-2)
+    assert recall == pytest.approx(expected[1], rel=1e-2)
+    assert f1 == pytest.approx(expected[2], rel=1e-2)
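A note on the hard-coded expected triple (0.333, 0.24, 0.2779) in the second parametrized case: the chains {A,B}, {C,D}, {F,G,H,I} versus {A,B,C}, {D,E,F,G} correspond to the worked example from the LEA paper (Moosavi & Strube, 2016), which yields precision = 1/3 ≈ 0.333, recall = 5/21 ≈ 0.238 and F1 = 5/18 ≈ 0.278. The sketch below is a minimal standalone LEA computation over plain string chains that reproduces those numbers; it is independent of tibert's score_lea / CoreferenceDocument API, and the helper names (lea, entity_links, lea_score) are illustrative only.

from itertools import combinations
from typing import List, Set, Tuple


def entity_links(chain: Set[str]) -> Set[Tuple[str, str]]:
    # LEA counts the coreference links inside an entity; a singleton is
    # credited with a single self-link so it can still be resolved.
    if len(chain) == 1:
        (mention,) = chain
        return {(mention, mention)}
    return set(combinations(sorted(chain), 2))


def lea_score(keys: List[Set[str]], responses: List[Set[str]]) -> float:
    # Each key entity is weighted by its size; its resolution score is the
    # fraction of its links that are found inside the response entities.
    num, den = 0.0, 0.0
    for key in keys:
        key_links = entity_links(key)
        resolved = sum(len(key_links & entity_links(resp)) for resp in responses)
        num += len(key) * resolved / len(key_links)
        den += len(key)
    return num / den if den else 0.0


def lea(pred: List[List[str]], ref: List[List[str]]) -> Tuple[float, float, float]:
    # Recall scores the reference entities against the prediction,
    # precision scores the predicted entities against the reference.
    pred_sets = [set(chain) for chain in pred]
    ref_sets = [set(chain) for chain in ref]
    precision = lea_score(pred_sets, ref_sets)
    recall = lea_score(ref_sets, pred_sets)
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Reproduces the second parametrized case: roughly (0.333, 0.238, 0.278).
print(lea([["A", "B"], ["C", "D"], ["F", "G", "H", "I"]],
          [["A", "B", "C"], ["D", "E", "F", "G"]]))

Checking by hand: recall = (3·(1/3) + 4·(1/6)) / 7 = 5/21 and precision = (2·1 + 2·0 + 4·(1/6)) / 8 = 1/3, so the test's rel=1e-2 tolerance comfortably covers the rounded constants 0.333, 0.24 and 0.2779.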