1212import itertools as it
1313from statistics import mean
1414import numpy as np
15- from neleval .coref_metrics import muc , b_cubed , ceaf
15+ from neleval .coref_metrics import muc , b_cubed , ceaf , pairwise , pairwise_negative
1616from tibert .utils import spans_indexs
1717
1818if TYPE_CHECKING :
@@ -38,15 +38,7 @@ def _coref_doc_to_neleval_format(doc: CoreferenceDocument, max_span_size: int):
3838 return clusters
3939
4040
41- def _neleval_precision_recall_f1 (
42- pred : CoreferenceDocument ,
43- ref : CoreferenceDocument ,
44- neleval_fn : Callable [
45- [Dict [int , Set [str ]], Dict [int , Set [str ]]],
46- Tuple [float , float , float , float ],
47- ],
48- ) -> Tuple [float , float , float ]:
49- """Get precision, recall and f1 for a predicted document from a neleval metrics."""
41+ def _max_span_size (pred : CoreferenceDocument , ref : CoreferenceDocument ) -> int :
5042 try :
5143 pred_max_span_size = max (
5244 [
@@ -57,6 +49,7 @@ def _neleval_precision_recall_f1(
5749 )
5850 except ValueError :
5951 pred_max_span_size = 0
52+
6053 try :
6154 ref_max_span_size = max (
6255 [
@@ -67,8 +60,20 @@ def _neleval_precision_recall_f1(
6760 )
6861 except ValueError :
6962 ref_max_span_size = 0
70- max_span_size = max (pred_max_span_size , ref_max_span_size )
71- # TODO max_span_size
63+
64+ return max (pred_max_span_size , ref_max_span_size )
65+
66+
67+ def _neleval_precision_recall_f1 (
68+ pred : CoreferenceDocument ,
69+ ref : CoreferenceDocument ,
70+ neleval_fn : Callable [
71+ [Dict [int , Set [str ]], Dict [int , Set [str ]]],
72+ Tuple [float , float , float , float ],
73+ ],
74+ ) -> Tuple [float , float , float ]:
75+ """Get precision, recall and f1 for a predicted document from a neleval metrics."""
76+ max_span_size = _max_span_size (pred , ref )
7277 neleval_pred = _coref_doc_to_neleval_format (pred , max_span_size + 1 )
7378 neleval_ref = _coref_doc_to_neleval_format (ref , max_span_size + 1 )
7479
@@ -140,7 +145,7 @@ def score_b_cubed(
140145 np .int = int # type: ignore
141146 np .bool = bool # type: ignore
142147
143- precisions , recalls , f1s = []
148+ precisions , recalls , f1s = [], [], []
144149 for pred , ref in zip (preds , refs ):
145150 p , r , f1 = _neleval_precision_recall_f1 (pred , ref , b_cubed )
146151 precisions .append (p )
@@ -170,7 +175,7 @@ def score_ceaf(
170175 np .int = int # type: ignore
171176 np .bool = bool # type: ignore
172177
173- precisions , recalls , f1s = []
178+ precisions , recalls , f1s = [], [], []
174179 for pred , ref in zip (preds , refs ):
175180 p , r , f1 = _neleval_precision_recall_f1 (pred , ref , ceaf )
176181 precisions .append (p )
@@ -180,6 +185,37 @@ def score_ceaf(
180185 return mean (precisions ), mean (recalls ), mean (f1s )
181186
182187
def score_blanc(
    preds: List[CoreferenceDocument], refs: List[CoreferenceDocument]
) -> Tuple[float, float, float]:
    """Score coreference predictions with a BLANC-style metric.

    For each ``(pred, ref)`` document pair, precision, recall and F1 are
    computed separately over the counts returned by neleval's ``pairwise``
    and ``pairwise_negative`` functions, and the two values are averaged.
    The per-document scores are then averaged over all documents.

    A ratio whose denominator is zero is scored as ``0.0`` instead of
    dividing by zero, so that degenerate documents (for instance a
    document for which one of the two link families is empty) do not
    make the whole evaluation fail.

    :param preds: predicted documents
    :param refs: reference documents, aligned with ``preds``

    :return: a tuple ``(precision, recall, f1)``, each being the mean of
        the per-document scores
    """
    assert len(preds) > 0
    assert len(preds) == len(refs)

    # NOTE(review): the other neleval-based scorers in this module alias
    # ``np.int`` / ``np.bool`` before calling neleval; this function does
    # not. Confirm that ``pairwise`` / ``pairwise_negative`` do not need it.

    def _ratio(num: float, den: float) -> float:
        # Guarded division: an empty denominator counts as a score of 0.
        return num / den if den else 0.0

    def _f1(precision: float, recall: float) -> float:
        # Harmonic mean of precision and recall, 0 when both are 0.
        return _ratio(2 * precision * recall, precision + recall)

    precisions, recalls, f1s = [], [], []

    for pred, ref in zip(preds, refs):
        max_span_size = _max_span_size(pred, ref)
        neleval_pred = _coref_doc_to_neleval_format(pred, max_span_size + 1)
        neleval_ref = _coref_doc_to_neleval_format(ref, max_span_size + 1)

        # each call returns numerators and denominators for precision
        # and recall, in that order
        p_num, p_den, r_num, r_den = pairwise(neleval_ref, neleval_pred)
        np_num, np_den, nr_num, nr_den = pairwise_negative(
            neleval_ref, neleval_pred
        )

        P_c = _ratio(p_num, p_den)
        P_n = _ratio(np_num, np_den)
        precisions.append((P_c + P_n) / 2.0)

        R_c = _ratio(r_num, r_den)
        R_n = _ratio(nr_num, nr_den)
        recalls.append((R_c + R_n) / 2.0)

        F_c = _f1(P_c, R_c)
        F_n = _f1(P_n, R_n)
        f1s.append((F_c + F_n) / 2.0)

    return mean(precisions), mean(recalls), mean(f1s)
217+
218+
183219def score_lea (
184220 preds : List [CoreferenceDocument ], refs : List [CoreferenceDocument ]
185221) -> Tuple [float , float , float ]:
@@ -247,11 +283,10 @@ def lea_res_score(entity: List[Mention], entities: List[List[Mention]]) -> float
247283def score_coref_predictions (
248284 preds : List [CoreferenceDocument ], refs : List [CoreferenceDocument ]
249285) -> Dict [
250- Literal ["MUC" , "B3" , "CEAF" ],
286+ Literal ["MUC" , "B3" , "CEAF" , "BLANC" , "LEA" ],
251287 Dict [Literal ["precision" , "recall" , "f1" ], float ],
252288]:
253- """Score coreference prediction according to MUC, B3 and CEAF
254- metrics
289+ """Score coreference prediction according to MUC, B3, CEAF, BLANC and LEA
255290
256291 .. note::
257292
@@ -263,6 +298,8 @@ def score_coref_predictions(
263298 muc_precision , muc_recall , muc_f1 = score_muc (preds , refs )
264299 b3_precision , b3_recall , b3_f1 = score_b_cubed (preds , refs )
265300 ceaf_precision , ceaf_recall , ceaf_f1 = score_ceaf (preds , refs )
301+ blanc_precision , blanc_recall , blanc_f1 = score_blanc (preds , refs )
302+ lea_precision , lea_recall , lea_f1 = score_lea (preds , refs )
266303
267304 return {
268305 "MUC" : {
@@ -280,6 +317,16 @@ def score_coref_predictions(
280317 "recall" : ceaf_recall ,
281318 "f1" : ceaf_f1 ,
282319 },
320+ "BLANC" : {
321+ "precision" : blanc_precision ,
322+ "recall" : blanc_recall ,
323+ "f1" : blanc_f1 ,
324+ },
325+ "LEA" : {
326+ "precision" : lea_precision ,
327+ "recall" : lea_recall ,
328+ "f1" : lea_f1 ,
329+ },
283330 }
284331
285332
0 commit comments