Commit ee7833d

WIP: BLANC
1 parent: 7d1465f

File tree: 2 files changed (+167, -18 lines)

tests/test_score.py

Lines changed: 103 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, Union, Literal
 import pytest
 from hypothesis import given, assume
 import hypothesis.strategies as st
@@ -41,3 +41,105 @@ def test_lea_canonical_examples(
     assert precision == pytest.approx(expected[0], rel=1e-2)
     assert recall == pytest.approx(expected[1], rel=1e-2)
     assert f1 == pytest.approx(expected[2], rel=1e-2)
+
+
+@pytest.mark.parametrize(
+    "pred,ref,expected",
+    [
+        ([["m1"]], [["m1"]], (1.0, 1.0, 1.0)),
+        (
+            [
+                ["m1"],
+                ["m2"],
+                ["m3"],
+                ["m4", "m6"],
+                ["m5", "m12"],
+                ["m7", "m9", "m14"],
+                ["m8"],
+                ["m10"],
+                ["m11"],
+                ["m13"],
+            ],
+            [
+                ["m1"],
+                ["m2"],
+                ["m3"],
+                ["m4"],
+                ["m5", "m12", "m14"],
+                ["m6"],
+                ["m7", "m9"],
+                ["m8"],
+                ["m10"],
+                ["m11"],
+                ["m13"],
+            ],
+            ("*", "*", 0.7078),
+        ),
+        (
+            [
+                ["0"],
+                ["1"],
+                ["2"],
+                ["3"],
+                ["4"],
+                ["5"],
+                ["6"],
+                ["7"],
+                ["8"],
+                ["9"],
+                ["10"],
+                ["11"],
+                ["12"],
+                ["13"],
+                ["14"],
+                ["15"],
+                ["16"],
+                ["17"],
+                ["18"],
+            ],
+            [
+                ["0"],
+                ["1"],
+                ["2"],
+                ["3"],
+                ["4"],
+                ["5"],
+                ["6"],
+                ["7"],
+                ["8"],
+                ["9"],
+                ["10"],
+                ["11"],
+                ["12"],
+                ["13"],
+                ["14"],
+                ["15"],
+                ["16"],
+                ["17", "18"],
+            ],
+            ("*", "*", 0.4984),
+        ),
+    ],
+)
+def test_blanc_canonical_examples(
+    pred: List[List[str]],
+    ref: List[List[str]],
+    expected: Tuple[
+        Union[float, Literal["*"]],
+        Union[float, Literal["*"]],
+        Union[float, Literal["*"]],
+    ],
+):
+    pred_doc = CoreferenceDocument(
+        list(flatten(pred)),
+        [[Mention([mention], 0, 0) for mention in chain] for chain in pred],
+    )
+    ref_doc = CoreferenceDocument(
+        list(flatten(ref)),
+        [[Mention([mention], 0, 0) for mention in chain] for chain in ref],
+    )
+
+    precision, recall, f1 = score_blanc([pred_doc], [ref_doc])
+    assert expected[0] == "*" or precision == pytest.approx(expected[0], rel=1e-2)
+    assert expected[1] == "*" or recall == pytest.approx(expected[1], rel=1e-2)
+    assert expected[2] == "*" or f1 == pytest.approx(expected[2], rel=1e-2)

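As a sanity check on the expected values above, BLANC can be recomputed directly from mention-pair sets. The sketch below is self-contained and ad hoc (blanc_from_chains is not part of tibert); it reproduces the 0.7078 F1 of the second example:

    from itertools import combinations

    def blanc_from_chains(pred, ref):
        # All mentions and all unordered mention pairs in the document.
        mentions = sorted({m for chain in pred + ref for m in chain})
        all_pairs = {frozenset(p) for p in combinations(mentions, 2)}
        # Coreferent and non-coreferent pair sets for prediction and reference.
        pred_c = {frozenset(p) for chain in pred for p in combinations(chain, 2)}
        ref_c = {frozenset(p) for chain in ref for p in combinations(chain, 2)}
        pred_n, ref_n = all_pairs - pred_c, all_pairs - ref_c

        def safe(num, den):
            return num / den if den else 0.0

        # Precision, recall and F1 over coreferent (c) and non-coreferent (n) pairs.
        P_c = safe(len(pred_c & ref_c), len(pred_c))
        R_c = safe(len(pred_c & ref_c), len(ref_c))
        P_n = safe(len(pred_n & ref_n), len(pred_n))
        R_n = safe(len(pred_n & ref_n), len(ref_n))
        F_c = safe(2 * P_c * R_c, P_c + R_c)
        F_n = safe(2 * P_n * R_n, P_n + R_n)
        # BLANC is the arithmetic mean of the two components.
        return (P_c + P_n) / 2, (R_c + R_n) / 2, (F_c + F_n) / 2

    pred = [["m1"], ["m2"], ["m3"], ["m4", "m6"], ["m5", "m12"],
            ["m7", "m9", "m14"], ["m8"], ["m10"], ["m11"], ["m13"]]
    ref = [["m1"], ["m2"], ["m3"], ["m4"], ["m5", "m12", "m14"], ["m6"],
           ["m7", "m9"], ["m8"], ["m10"], ["m11"], ["m13"]]
    print(blanc_from_chains(pred, ref)[2])  # ~0.7078

The all-singletons example works the same way: with no predicted coreference links the coreferent component is 0, the non-coreferent component is close to 1, and their mean is the expected 0.4984.
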
tibert/score.py

Lines changed: 64 additions & 17 deletions
@@ -12,7 +12,7 @@
 import itertools as it
 from statistics import mean
 import numpy as np
-from neleval.coref_metrics import muc, b_cubed, ceaf
+from neleval.coref_metrics import muc, b_cubed, ceaf, pairwise, pairwise_negative
 from tibert.utils import spans_indexs
 
 if TYPE_CHECKING:
@@ -38,15 +38,7 @@ def _coref_doc_to_neleval_format(doc: CoreferenceDocument, max_span_size: int):
     return clusters
 
 
-def _neleval_precision_recall_f1(
-    pred: CoreferenceDocument,
-    ref: CoreferenceDocument,
-    neleval_fn: Callable[
-        [Dict[int, Set[str]], Dict[int, Set[str]]],
-        Tuple[float, float, float, float],
-    ],
-) -> Tuple[float, float, float]:
-    """Get precision, recall and f1 for a predicted document from a neleval metrics."""
+def _max_span_size(pred: CoreferenceDocument, ref: CoreferenceDocument) -> int:
     try:
         pred_max_span_size = max(
             [
@@ -57,6 +49,7 @@ def _neleval_precision_recall_f1(
         )
     except ValueError:
         pred_max_span_size = 0
+
     try:
         ref_max_span_size = max(
             [
@@ -67,8 +60,20 @@
         )
     except ValueError:
         ref_max_span_size = 0
-    max_span_size = max(pred_max_span_size, ref_max_span_size)
-    # TODO max_span_size
+
+    return max(pred_max_span_size, ref_max_span_size)
+
+
+def _neleval_precision_recall_f1(
+    pred: CoreferenceDocument,
+    ref: CoreferenceDocument,
+    neleval_fn: Callable[
+        [Dict[int, Set[str]], Dict[int, Set[str]]],
+        Tuple[float, float, float, float],
+    ],
+) -> Tuple[float, float, float]:
+    """Get precision, recall and f1 for a predicted document from a neleval metric."""
+    max_span_size = _max_span_size(pred, ref)
     neleval_pred = _coref_doc_to_neleval_format(pred, max_span_size + 1)
     neleval_ref = _coref_doc_to_neleval_format(ref, max_span_size + 1)
 
@@ -140,7 +145,7 @@ def score_b_cubed(
     np.int = int  # type: ignore
     np.bool = bool  # type: ignore
 
-    precisions, recalls, f1s = []
+    precisions, recalls, f1s = [], [], []
     for pred, ref in zip(preds, refs):
         p, r, f1 = _neleval_precision_recall_f1(pred, ref, b_cubed)
         precisions.append(p)
@@ -170,7 +175,7 @@
     np.int = int  # type: ignore
     np.bool = bool  # type: ignore
 
-    precisions, recalls, f1s = []
+    precisions, recalls, f1s = [], [], []
     for pred, ref in zip(preds, refs):
         p, r, f1 = _neleval_precision_recall_f1(pred, ref, ceaf)
         precisions.append(p)
@@ -180,6 +185,37 @@
     return mean(precisions), mean(recalls), mean(f1s)
 
 
+def score_blanc(
+    preds: List[CoreferenceDocument], refs: List[CoreferenceDocument]
+) -> Tuple[float, float, float]:
+    """Score BLANC (Recasens and Hovy, 2011), averaged over documents."""
+    assert len(preds) > 0
+    assert len(preds) == len(refs)
+
+    def safe_div(num: float, den: float) -> float:
+        # A component with an empty denominator (no coreference links, or no
+        # non-coreference links, in the document) is scored as 0, as in the
+        # all-singletons test case.
+        return num / den if den > 0 else 0.0
+
+    precisions, recalls, f1s = [], [], []
+
+    for pred, ref in zip(preds, refs):
+        max_span_size = _max_span_size(pred, ref)
+        neleval_pred = _coref_doc_to_neleval_format(pred, max_span_size + 1)
+        neleval_ref = _coref_doc_to_neleval_format(ref, max_span_size + 1)
+
+        p_num, p_den, r_num, r_den = pairwise(neleval_ref, neleval_pred)
+        np_num, np_den, nr_num, nr_den = pairwise_negative(neleval_ref, neleval_pred)
+
+        P_c = safe_div(p_num, p_den)
+        P_n = safe_div(np_num, np_den)
+        precisions.append((P_c + P_n) / 2.0)
+
+        R_c = safe_div(r_num, r_den)
+        R_n = safe_div(nr_num, nr_den)
+        recalls.append((R_c + R_n) / 2.0)
+
+        F_c = safe_div(2 * P_c * R_c, P_c + R_c)
+        F_n = safe_div(2 * P_n * R_n, P_n + R_n)
+        f1s.append((F_c + F_n) / 2.0)
+
+    return mean(precisions), mean(recalls), mean(f1s)
+
+
 def score_lea(
     preds: List[CoreferenceDocument], refs: List[CoreferenceDocument]
 ) -> Tuple[float, float, float]:
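
BLANC is the arithmetic mean of two pairwise F-scores: one over coreferent mention pairs (neleval's pairwise) and one over non-coreferent pairs (pairwise_negative), each returned as a (p_num, p_den, r_num, r_den) tuple of numerators and denominators. A minimal usage sketch, assuming the same imports as tests/test_score.py (CoreferenceDocument, Mention, flatten, score_blanc):

    chains = [["m1", "m2"], ["m3"]]
    doc = CoreferenceDocument(
        list(flatten(chains)),
        [[Mention([m], 0, 0) for m in chain] for chain in chains],
    )
    # A prediction identical to the reference is perfect on both the
    # coreferent and the non-coreferent pair component.
    print(score_blanc([doc], [doc]))  # (1.0, 1.0, 1.0)
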
@@ -247,11 +283,10 @@ def lea_res_score(entity: List[Mention], entities: List[List[Mention]]) -> float
 def score_coref_predictions(
     preds: List[CoreferenceDocument], refs: List[CoreferenceDocument]
 ) -> Dict[
-    Literal["MUC", "B3", "CEAF"],
+    Literal["MUC", "B3", "CEAF", "BLANC", "LEA"],
     Dict[Literal["precision", "recall", "f1"], float],
 ]:
-    """Score coreference prediction according to MUC, B3 and CEAF
-    metrics
+    """Score coreference prediction according to MUC, B3, CEAF, BLANC and LEA metrics
 
     .. note::
@@ -263,6 +298,8 @@ def score_coref_predictions(
     muc_precision, muc_recall, muc_f1 = score_muc(preds, refs)
     b3_precision, b3_recall, b3_f1 = score_b_cubed(preds, refs)
     ceaf_precision, ceaf_recall, ceaf_f1 = score_ceaf(preds, refs)
+    blanc_precision, blanc_recall, blanc_f1 = score_blanc(preds, refs)
+    lea_precision, lea_recall, lea_f1 = score_lea(preds, refs)
 
     return {
         "MUC": {
@@ -280,6 +317,16 @@
             "recall": ceaf_recall,
             "f1": ceaf_f1,
         },
+        "BLANC": {
+            "precision": blanc_precision,
+            "recall": blanc_recall,
+            "f1": blanc_f1,
+        },
+        "LEA": {
+            "precision": lea_precision,
+            "recall": lea_recall,
+            "f1": lea_f1,
+        },
     }
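
With these additions, score_coref_predictions reports BLANC and LEA alongside MUC, B3 and CEAF. A short usage sketch, with pred_doc and ref_doc built as in the test above:

    scores = score_coref_predictions([pred_doc], [ref_doc])
    for metric in ("MUC", "B3", "CEAF", "BLANC", "LEA"):
        values = scores[metric]
        print(metric, values["precision"], values["recall"], values["f1"])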