
Commit b05e06f

Add mention detection scoring; report it and the CoNLL F1 during training
1 parent: 6f4a4b4

2 files changed: +73 −11 lines changed

tibert/score.py

Lines changed: 55 additions & 1 deletion
@@ -5,7 +5,7 @@
 from tibert.utils import spans_indexs
 
 if TYPE_CHECKING:
-    from tibert.bertcoref import CoreferenceDocument
+    from tibert.bertcoref import CoreferenceDocument, Mention
 
 
 def score_coref_predictions(
@@ -133,3 +133,57 @@ def precisions_recalls_f1s(
             "f1": mean(ceaf_f1s),
         },
     }
+
+
+def doc_mentions(doc: CoreferenceDocument) -> List[Mention]:
+    return [mention for chain in doc.coref_chains for mention in chain]
+
+
+def score_mention_detection(
+    preds: List[CoreferenceDocument], refs: List[CoreferenceDocument]
+) -> Tuple[float, float, float]:
+    """Compute mention detection precision, recall and F1.
+
+    :param preds: predictions
+    :param refs: references
+
+    :return: ``(precision, recall, f1)``
+    """
+    assert len(preds) > 0
+    assert len(refs) > 0
+
+    precision_l = []
+    recall_l = []
+    f1_l = []
+
+    for pred, ref in zip(preds, refs):
+
+        pred_mentions = doc_mentions(pred)
+        ref_mentions = doc_mentions(ref)
+
+        if len(pred_mentions) == 0:
+            continue
+        precision = len([m for m in pred_mentions if m in ref_mentions]) / len(
+            pred_mentions
+        )
+
+        if len(ref_mentions) == 0:
+            continue
+        recall = len([m for m in ref_mentions if m in pred_mentions]) / len(
+            ref_mentions
+        )
+
+        if precision + recall == 0:
+            continue
+
+        f1 = 2 * (precision * recall) / (precision + recall)
+
+        precision_l.append(precision)
+        recall_l.append(recall)
+        f1_l.append(f1)
+
+    if len(f1_l) == 0:
+        print("[warning] undefined F1 for all samples")
+        return (0.0, 0.0, 0.0)
+
+    return (mean(precision_l), mean(recall_l), mean(f1_l))
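
For reference, score_mention_detection treats mention detection as plain set overlap between the predicted and gold mention spans (pooled across all coreference chains), then macro-averages precision, recall and F1 over documents, skipping a document when any of the three would be undefined. A minimal sketch of the per-document arithmetic, using hypothetical (start, end) tuples as stand-ins for tibert's Mention objects:

# A minimal sketch of the per-document scoring above; the (start, end)
# tuples are hypothetical stand-ins for tibert's Mention objects.
pred_mentions = [(0, 2), (5, 6), (9, 11)]  # hypothetical predicted spans
ref_mentions = [(0, 2), (5, 6), (7, 8)]    # hypothetical reference spans

# Precision: fraction of predicted mentions found in the reference.
precision = len([m for m in pred_mentions if m in ref_mentions]) / len(pred_mentions)
# Recall: fraction of reference mentions that were predicted.
recall = len([m for m in ref_mentions if m in pred_mentions]) / len(ref_mentions)
# F1: harmonic mean of precision and recall.
f1 = 2 * (precision * recall) / (precision + recall)

print(precision, recall, f1)  # each ≈ 0.667 here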

tibert/train.py

Lines changed: 18 additions & 10 deletions
@@ -4,7 +4,7 @@
 from more_itertools.recipes import flatten
 import torch
 from torch.utils.data.dataloader import DataLoader
-from transformers import BertTokenizerFast  # type: ignore
+from transformers import BertTokenizerFast, CamembertTokenizerFast  # type: ignore
 from tqdm import tqdm
 from tibert import (
     BertForCoreferenceResolution,
@@ -14,14 +14,15 @@
     split_coreference_document,
     DataCollatorForSpanClassification,
     score_coref_predictions,
+    score_mention_detection,
 )
 from tibert.utils import gpu_memory_usage
 
 
 def train_coref_model(
     model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
     dataset: CoreferenceDataset,
-    tokenizer: BertTokenizerFast,
+    tokenizer: Union[BertTokenizerFast, CamembertTokenizerFast],
     batch_size: int = 1,
     epochs_nb: int = 30,
     sents_per_documents_train: int = 11,
@@ -150,8 +151,11 @@ def train_coref_model(
                 )[0]
                 for doc in test_dataset.documents
             ]
-            metrics = score_coref_predictions(preds, refs)
 
+            metrics = score_coref_predictions(preds, refs)
+            conll_f1 = mean(
+                [metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]]
+            )
             if _run:
                 _run.log_scalar("muc_precision", metrics["MUC"]["precision"])
                 _run.log_scalar("muc_recall", metrics["MUC"]["recall"])
@@ -162,23 +166,27 @@ def train_coref_model(
                 _run.log_scalar("ceaf_precision", metrics["CEAF"]["precision"])
                 _run.log_scalar("ceaf_recall", metrics["CEAF"]["recall"])
                 _run.log_scalar("ceaf_f1", metrics["CEAF"]["f1"])
-
+                _run.log_scalar("conll_f1", conll_f1)
             print(metrics)
 
-            # keep the best model
-            model_f1 = mean(
-                [metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]]
+            m_precision, m_recall, m_f1 = score_mention_detection(preds, refs)
+            if _run:
+                _run.log_scalar("mention_detection_precision", m_precision)
+                _run.log_scalar("mention_detection_recall", m_recall)
+                _run.log_scalar("mention_detection_f1", m_f1)
+            print(
+                f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
             )
 
         except Exception as e:
             print(e)
             traceback.print_exc()
-            model_f1 = 0
+            conll_f1 = 0
 
-        if model_f1 > best_f1 or best_f1 == 0:
+        if conll_f1 > best_f1 or best_f1 == 0:
             best_model = copy.deepcopy(model).to("cpu")
             if not model_save_path is None:
                 best_model.save_pretrained(model_save_path)
-            best_f1 = model_f1
+            best_f1 = conll_f1
 
     return best_model
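
For reference, the conll_f1 computed above is the unweighted mean of the MUC, B3 and CEAF F1 scores, following the CoNLL-2012 shared task convention. A quick worked example with hypothetical scores, shaped like the nested dict that score_coref_predictions returns:

from statistics import mean

# Hypothetical per-metric F1 scores (not real results).
metrics = {"MUC": {"f1": 0.80}, "B3": {"f1": 0.70}, "CEAF": {"f1": 0.60}}

# CoNLL F1 is the plain average of the three coreference F1 scores.
conll_f1 = mean([metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]])
print(conll_f1)  # ≈ 0.70

Collapsing the three metrics into this single scalar gives the training loop one number to rank checkpoints by, which is what the conll_f1 > best_f1 check above does.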
