Skip to content

Commit 7d1465f

Browse files
committed
Add LEA metrics and allow separate computation of other metrics
1 parent a120554 commit 7d1465f

File tree

2 files changed

+284
-104
lines changed

2 files changed

+284
-104
lines changed

tests/test_score.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,43 @@
1-
from typing import List
1+
from typing import List, Tuple
2+
import pytest
23
from hypothesis import given, assume
34
import hypothesis.strategies as st
4-
from tibert import score_mention_detection, CoreferenceDocument
5+
from tibert.bertcoref import Mention
6+
from tibert import score_mention_detection, CoreferenceDocument, score_lea
57
from tests.strategies import coref_docs
8+
from more_itertools import flatten
69

710

811
@given(docs=st.lists(coref_docs(min_size=1, max_size=32), min_size=1, max_size=3))
def test_mention_score_perfect_when_same_docs(docs: List[CoreferenceDocument]):
    """Scoring a set of documents against itself yields perfect scores.

    Generated examples in which any document has no coreference chain are
    discarded through ``assume`` before the assertion runs.
    """
    every_doc_has_chains = all(len(doc.coref_chains) > 0 for doc in docs)
    assume(every_doc_has_chains)

    perfect_scores = (1.0, 1.0, 1.0)
    assert score_mention_detection(docs, docs) == perfect_scores
15+
16+
17+
@pytest.mark.parametrize(
    "pred,ref,expected",
    [
        ([["A"]], [["A"]], (1.0, 1.0, 1.0)),
        (
            [["A", "B"], ["C", "D"], ["F", "G", "H", "I"]],
            [["A", "B", "C"], ["D", "E", "F", "G"]],
            (0.333, 0.24, 0.2779),
        ),
    ],
)
def test_lea_canonical_examples(
    pred: List[List[str]], ref: List[List[str]], expected: Tuple[float, float, float]
):
    """Check ``score_lea`` against hand-specified expected score triples.

    ``pred`` and ``ref`` are lists of chains, each chain a list of
    single-token mention labels.  ``expected`` holds the three values that
    ``score_lea`` should return, compared in order with a 1% relative
    tolerance.
    """

    def build_document(chains: List[List[str]]) -> CoreferenceDocument:
        # The document tokens are simply all mention labels, in chain order.
        tokens = list(flatten(chains))
        # NOTE(review): the two trailing zeros are presumably placeholder
        # start/end positions -- confirm against the Mention definition.
        mention_chains = [
            [Mention([label], 0, 0) for label in chain] for chain in chains
        ]
        return CoreferenceDocument(tokens, mention_chains)

    pred_doc = build_document(pred)
    ref_doc = build_document(ref)

    # Unpack explicitly so an unexpected number of returned values still fails.
    precision, recall, f1 = score_lea([pred_doc], [ref_doc])

    for actual, wanted in zip((precision, recall, f1), expected):
        assert actual == pytest.approx(wanted, rel=1e-2)

0 commit comments

Comments
 (0)