
Commit 76a12a0

Add BLANC
1 parent ee7833d commit 76a12a0

File tree
  tests/test_score.py
  tibert/score.py

2 files changed: +98 -80 lines


tests/test_score.py

Lines changed: 75 additions & 67 deletions
@@ -7,6 +7,8 @@
 from tests.strategies import coref_docs
 from more_itertools import flatten
 
+from tibert.score import score_blanc
+
 
 @given(docs=st.lists(coref_docs(min_size=1, max_size=32), min_size=1, max_size=3))
 def test_mention_score_perfect_when_same_docs(docs: List[CoreferenceDocument]):
@@ -46,100 +48,106 @@ def test_lea_canonical_examples(
 @pytest.mark.parametrize(
     "pred,ref,expected",
     [
-        ([["m1"]], [["m1"]], (1.0, 1.0, 1.0)),
+        ([[0]], [[0]], (1.0, 1.0, 1.0)),
         (
             [
-                ["m1"],
-                ["m2"],
-                ["m3"],
-                ["m4, m6"],
-                ["m5", "m12"],
-                ["m7", "m9", "m14"],
-                ["m8"],
-                ["m10"],
-                ["m11"],
-                ["m13"],
+                [0],
+                [1],
+                [2],
+                [3, 5],
+                [4, 11],
+                [6, 8, 13],
+                [7],
+                [9],
+                [10],
+                [12],
             ],
             [
-                ["m1"],
-                ["m2"],
-                ["m3"],
-                ["m4"],
-                ["m5", "m12", "m14"],
-                ["m6"],
-                ["m7", "m9"],
-                ["m8"],
-                ["m10"],
-                ["m11"],
-                ["m13"],
+                [0],
+                [1],
+                [2],
+                [3],
+                [4, 11, 13],
+                [5],
+                [6, 8],
+                [7],
+                [9],
+                [10],
+                [12],
             ],
             ("*", "*", 0.7078),
         ),
         (
             [
-                ["0"],
-                ["1"],
-                ["2"],
-                ["3"],
-                ["4"],
-                ["5"],
-                ["6"],
-                ["7"],
-                ["8"],
-                ["9"],
-                ["10"],
-                ["11"],
-                ["12"],
-                ["13"],
-                ["14"],
-                ["15"],
-                ["16"],
-                ["17"],
-                ["18"],
+                [0],
+                [1],
+                [2],
+                [3],
+                [4],
+                [5],
+                [6],
+                [7],
+                [8],
+                [9],
+                [10],
+                [11],
+                [12],
+                [13],
+                [14],
+                [15],
+                [16],
+                [17],
+                [18],
             ],
             [
-                ["0"],
-                ["1"],
-                ["2"],
-                ["3"],
-                ["4"],
-                ["5"],
-                ["6"],
-                ["7"],
-                ["8"],
-                ["9"],
-                ["10"],
-                ["11"],
-                ["12"],
-                ["13"],
-                ["14"],
-                ["15"],
-                ["16"],
-                ["17", "18"],
+                [0],
+                [1],
+                [2],
+                [3],
+                [4],
+                [5],
+                [6],
+                [7],
+                [8],
+                [9],
+                [10],
+                [11],
+                [12],
+                [13],
+                [14],
+                [15],
+                [16],
+                [17, 18],
             ],
-            ("*", "*", "0.4984"),
+            ("*", "*", 0.4984),
         ),
     ],
 )
 def test_blanc_canonical_examples(
-    pred: List[List[str]],
-    ref: List[List[str]],
+    pred: List[List[int]],
+    ref: List[List[int]],
     expected: Tuple[
         Union[float, Literal["*"]],
         Union[float, Literal["*"]],
         Union[float, Literal["*"]],
     ],
 ):
     pred_doc = CoreferenceDocument(
-        list(flatten(pred)),
-        [[Mention([mention], 0, 0) for mention in chain] for chain in pred],
+        [str(m) for m in flatten(pred)],
+        [
+            [Mention([str(mention)], mention, mention + 1) for mention in chain]
+            for chain in pred
+        ],
     )
     ref_doc = CoreferenceDocument(
-        list(flatten(ref)),
-        [[Mention([mention], 0, 0) for mention in chain] for chain in ref],
+        [str(m) for m in flatten(ref)],
+        [
+            [Mention([str(mention)], mention, mention + 1) for mention in chain]
+            for chain in ref
+        ],
    )
 
-    precision, recall, f1 = score_lea([pred_doc], [ref_doc])
+    precision, recall, f1 = score_blanc([pred_doc], [ref_doc])
     assert expected[0] == "*" or precision == pytest.approx(expected[0], rel=1e-2)
     assert expected[1] == "*" or recall == pytest.approx(expected[1], rel=1e-2)
     assert expected[2] == "*" or f1 == pytest.approx(expected[2], rel=1e-2)
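
As a quick illustration of how the parametrized cases above are turned into documents: each integer m in a chain becomes a one-token Mention covering the token span [m, m + 1), and the document's tokens are just the stringified indices. The sketch below restates that construction outside pytest; it is a minimal example, the toy chains are invented, and the import path for CoreferenceDocument and Mention is an assumption (the test module imports them outside this hunk).

    from more_itertools import flatten

    from tibert.score import score_blanc
    from tibert.bertcoref import CoreferenceDocument, Mention  # import path assumed

    pred = [[0], [1, 2]]  # hypothetical chains: one singleton, one two-mention entity
    tokens = [str(m) for m in flatten(pred)]  # ["0", "1", "2"]
    chains = [
        [Mention([str(m)], m, m + 1) for m in chain]  # one-token mention per index
        for chain in pred
    ]
    doc = CoreferenceDocument(tokens, chains)

    # Scoring a document against itself exercises the identical-chains
    # short-circuit added to score_blanc below, so this yields (1, 1, 1).
    precision, recall, f1 = score_blanc([doc], [doc])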

tibert/score.py

Lines changed: 23 additions & 13 deletions
@@ -191,29 +191,41 @@ def score_blanc(
     assert len(preds) > 0
     assert len(preds) == len(refs)
 
-    precisions, recalls, f1s = [], [], []
+    prf = []
 
     for pred, ref in zip(preds, refs):
+        if pred.coref_chains == ref.coref_chains:
+            prf.append((1, 1, 1))
+            continue
+
         max_span_size = _max_span_size(pred, ref)
         neleval_pred = _coref_doc_to_neleval_format(pred, max_span_size + 1)
         neleval_ref = _coref_doc_to_neleval_format(ref, max_span_size + 1)
 
         p_num, p_den, r_num, r_den = pairwise(neleval_ref, neleval_pred)
         np_num, np_den, nr_num, nr_den = pairwise_negative(neleval_ref, neleval_pred)
 
-        P_c = p_num / p_den
-        P_n = np_num / np_den
-        precisions.append((P_c + P_n) / 2.0)
+        # pred_has_one_entity = len(pred.coref_chains) == 1
+        # pred_has_only_singletons = all([len(chain) == 1 for chain in pred.coref_chains])
+        # ref_has_one_entity = len(ref.coref_chains) == 1
+        # ref_has_only_singletons = all([len(chain) == 1 for chain in ref.coref_chains])
 
-        R_c = r_num / r_den
-        R_n = nr_num / nr_den
-        recalls.append((R_c + R_n) / 2.0)
+        P_c = 0 if p_den == 0 else p_num / p_den
+        P_n = 0 if np_den == 0 else np_num / np_den
 
-        F_c = (2 * P_c * R_c) / (P_c + R_c)
-        F_n = (2 * P_n * R_n) / (P_n + R_n)
-        f1s.append((F_c + F_n) / 2.0)
+        R_c = 0 if r_den == 0 else r_num / r_den
+        R_n = 0 if nr_den == 0 else nr_num / nr_den
 
-    return mean(precisions), mean(recalls), mean(f1s)
+        F_c = 0 if P_c + R_c == 0 else (2 * P_c * R_c) / (P_c + R_c)
+        F_n = 0 if P_n + R_n == 0 else (2 * P_n * R_n) / (P_n + R_n)
+
+        prf.append(((P_c + P_n) / 2.0, (R_c + R_n) / 2.0, (F_c + F_n) / 2.0))
+
+    return (
+        mean([m[0] for m in prf]),
+        mean([m[1] for m in prf]),
+        mean([m[2] for m in prf]),
+    )
 
 
 def score_lea(
@@ -251,7 +263,6 @@ def lea_res_score(entity: List[Mention], entities: List[List[Mention]]) -> float
     precisions, recalls, f1s = [], [], []
 
     for pred, ref in zip(preds, refs):
-
         precision_num = 0
         precision_den = 0
         for pred_chain in pred.coref_chains:
@@ -352,7 +363,6 @@ def score_mention_detection(
     f1_l = []
 
     for pred, ref in zip(preds, refs):
-
         pred_mentions = doc_mentions(pred)
         ref_mentions = doc_mentions(ref)
 
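For readers who want the metric spelled out: the quantities in the score_blanc hunk follow the usual BLANC decomposition. P_c and R_c are link precision and recall over coreference links, P_n and R_n the same over non-coreference links, and the per-document precision, recall and F1 are the plain averages of the two sides; the new 0-if-denominator-is-zero guards handle degenerate documents (for instance a prediction made only of singletons, which has no coreference links at all). A minimal sketch of that per-document combination, using the same variable names as the diff (blanc_prf is a hypothetical helper, not part of tibert):

    def blanc_prf(p_num, p_den, np_num, np_den, r_num, r_den, nr_num, nr_den):
        # Link-level precision/recall over coreference ("_c") and
        # non-coreference ("_n") links, guarded against empty denominators
        # exactly as in score_blanc above.
        P_c = 0 if p_den == 0 else p_num / p_den
        P_n = 0 if np_den == 0 else np_num / np_den
        R_c = 0 if r_den == 0 else r_num / r_den
        R_n = 0 if nr_den == 0 else nr_num / nr_den
        F_c = 0 if P_c + R_c == 0 else 2 * P_c * R_c / (P_c + R_c)
        F_n = 0 if P_n + R_n == 0 else 2 * P_n * R_n / (P_n + R_n)
        # BLANC reports the mean of the coreference and non-coreference scores;
        # score_blanc then averages these per-document values over all documents.
        return (P_c + P_n) / 2, (R_c + R_n) / 2, (F_c + F_n) / 2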