
Commit a1c97fe

add mention detection metrics reporting for train/test scripts
1 parent 82fd0d0 commit a1c97fe

File tree

2 files changed: +25 -2 lines changed

  tibert/run_test.py
  tibert/run_train.py

tibert/run_test.py

Lines changed: 13 additions & 2 deletions
@@ -15,7 +15,7 @@
     BertForCoreferenceResolution,
     CamembertForCoreferenceResolution,
 )
-from tibert.score import score_coref_predictions
+from tibert.score import score_coref_predictions, score_mention_detection
 from tibert.predict import predict_coref
 from tibert.utils import split_coreference_document_tokens
 
@@ -82,7 +82,7 @@ def main(
     all_annotated_docs = []
     for document in tqdm(test_dataset.documents):
         doc_dataset = CoreferenceDataset(
-            split_coreference_document_tokens(document, 512),
+            [document],
             tokenizer,
             max_span_size,
         )
@@ -110,6 +110,17 @@ def main(
         annotated_doc = CoreferenceDocument.concatenated(annotated_docs)
         all_annotated_docs.append(annotated_doc)
 
+    mention_pre, mention_rec, mention_f1 = score_mention_detection(
+        all_annotated_docs, test_dataset.documents
+    )
+    for metric_key, score in [
+        ("precision", mention_pre),
+        ("recall", mention_rec),
+        ("f1", mention_f1),
+    ]:
+        print(f"mention.{metric_key}={score}")
+        _run.log_scalar(f"mention.{metric_key}", score)
+
     scores = score_coref_predictions(all_annotated_docs, test_dataset.documents)
     for key, score_dict in scores.items():
         for metric_key, score in score_dict.items():
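
For reference, the scorer itself is not shown in this commit. The sketch below illustrates what a mention detection scorer with this return shape plausibly computes — micro-averaged precision, recall, and F1 over exact span matches between predicted and reference mentions. The Span alias and the function name are hypothetical, not the actual tibert.score API:

from typing import List, Set, Tuple

# Hypothetical mention representation: (start, end) token indices.
Span = Tuple[int, int]

def mention_detection_scores(
    pred_mentions: List[Set[Span]], ref_mentions: List[Set[Span]]
) -> Tuple[float, float, float]:
    """Micro-averaged precision/recall/F1 over exact mention span matches."""
    tp = fp = fn = 0
    for pred, ref in zip(pred_mentions, ref_mentions):
        tp += len(pred & ref)  # spans predicted and present in the reference
        fp += len(pred - ref)  # spans predicted but absent from the reference
        fn += len(ref - pred)  # reference spans the model missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1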

tibert/run_train.py

Lines changed: 12 additions & 0 deletions
@@ -14,6 +14,7 @@
     load_train_checkpoint,
     predict_coref,
     score_coref_predictions,
+    score_mention_detection,
 )
 from tibert.bertcoref import CoreferenceDataset, load_democrat_dataset
 
@@ -145,6 +146,17 @@ def main(
     )
     assert isinstance(annotated_docs, list)
 
+    mention_pre, mention_rec, mention_f1 = score_mention_detection(
+        annotated_docs, test_dataset.documents
+    )
+    for metric_key, score in [
+        ("precision", mention_pre),
+        ("recall", mention_rec),
+        ("f1", mention_f1),
+    ]:
+        print(f"mention.{metric_key}={score}")
+        _run.log_scalar(f"mention.{metric_key}", score)
+
     metrics = score_coref_predictions(annotated_docs, test_dataset.documents)
     print(metrics)
     for key, score_dict in metrics.items():
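
The _run.log_scalar calls above use Sacred's run object, which these scripts presumably receive through Sacred's dependency injection. A minimal standalone sketch of that logging pattern, with a FileStorageObserver to persist the metrics and placeholder scores purely for illustration:

from sacred import Experiment
from sacred.observers import FileStorageObserver

ex = Experiment("mention_metrics_demo")  # illustrative experiment name
ex.observers.append(FileStorageObserver("runs"))  # persists metrics per run

@ex.automain
def main(_run):
    # Placeholder scores, for illustration only.
    for metric_key, score in [("precision", 0.9), ("recall", 0.8), ("f1", 0.847)]:
        print(f"mention.{metric_key}={score}")
        # log_scalar(metric_name, value, step=None) appends to a named
        # metric series that attached observers record with the run.
        _run.log_scalar(f"mention.{metric_key}", score)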
